#include "AllShader.hpp"
// Metal shader source, stored as a C string, for the clamp activation kernels.
//   relu6_x1: out[gid] = clamp(in[gid], minMax.x, minMax.y) on scalar M elements.
//   relu6_x4: the same clamp applied to M4 vector elements.
// The clamp bounds are read from the buffer(2) constant (only .x and .y are used),
// so they are not fixed to 0 and 6 by the shader itself.
// Neither kernel bounds-checks gid.
// NOTE(review): M and M4 are not defined in this string; presumably they come from
// a prelude that is prepended before compilation - confirm.
const char* shader_MetalReLU6_metal = 
"kernel void relu6_x1(const device M *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" constant float4 &minMax [[buffer(2)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=clamp(in[int(gid)],(M)(minMax.x),(M)(minMax.y));\n"
"}\n"
// Vectorised variant: one thread handles one M4 element.
"kernel void relu6_x4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant float4 &minMax [[buffer(2)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=clamp(in[int(gid)],(M4)minMax.x,(M4)minMax.y);\n"
"}\n"
;
// Metal shader source, stored as a C string, for ReLU with a negative-side slope.
// Both kernels compute out = fmax(V, 0) + fmin(V, 0) * slope, i.e. positive values
// pass through and negative values are scaled by the slope read from buffer(2).
//   relu_x1 works on scalar M elements, relu_x4 on M4 vector elements.
// Neither kernel bounds-checks gid.
// NOTE(review): M and M4 are not defined in this string; presumably they come from
// a prelude that is prepended before compilation - confirm.
const char* shader_MetalReLU_metal = 
"kernel void relu_x1(const device M *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" constant float &slope [[buffer(2)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" auto V=in[int(gid)];\n"
" out[int(gid)]=fmax(V,(M)0)+fmin(V,(M)0)*M(slope);\n"
"}\n"
// Vectorised variant: the scalar slope is broadcast into an M4.
"kernel void relu_x4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant float &slope [[buffer(2)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" auto V=in[int(gid)];\n"
" out[int(gid)]=fmax(V,(M4)0)+fmin(V,(M4)0)*M4(slope);\n"
"}\n"
;
// Metal shader source, stored as a C string, for the depthwise convolution kernel.
// The string declares the conv_dw_cst parameter struct and the conv_depthwise kernel.
// It relies on names that are NOT defined in this string: M4, FLOAT4, UP_DIV,
// conv_activation_type and activate(). conv_activation_type and activate() are
// declared in shader_MetalConvolutionActivation_metal in this file.
// NOTE(review): presumably the strings are concatenated with a common prelude
// before compilation - confirm against the code that builds the shader library.
const char* shader_MetalConvolutionDepthwise_metal = 
// Parameter block bound at buffer(2); the field order is part of the host/shader ABI.
"struct conv_dw_cst {\n"
" int input_width;\n"
" int input_height;\n"
" int input_size;\n"
" int output_width;\n"
" int output_height;\n"
" int output_size;\n"
" int slice;\n"
" int batch;\n"
" \n"
" int kernel_x;\n"
" int kernel_y;\n"
" int kernel_size;\n"
" int stride_x;\n"
" int stride_y;\n"
" int pad_x;\n"
" int pad_y;\n"
" int dilation_x;\n"
" int dilation_y;\n"
" conv_activation_type activation;\n"
"};\n"
// One thread per output element: gid.x / gid.y index the output plane and gid.z
// indexes slice*batch planes. Threads outside that range return immediately.
"kernel void conv_depthwise(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv_dw_cst& cst [[buffer(2)]],\n"
" const device M4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.slice*cst.batch) return;\n"
" \n"
// oz selects the weight and bias slice; weights are shared across the batch because
// only gid.z % slice is used to index them.
" int oz=gid.z % cst.slice;\n"
" int offset_x=(int)gid.x*cst.stride_x-cst.pad_x;\n"
" int offset_y=(int)gid.y*cst.stride_y-cst.pad_y;\n"
// [sx,ex) and [sy,ey) restrict the kernel taps that are visited.
// NOTE(review): presumably UP_DIV is a ceiling division, which makes these the
// taps that fall inside the input plane - confirm where UP_DIV is defined.
" int sx=max(0,(UP_DIV(-offset_x,cst.dilation_x)));\n"
" int ex=min(cst.kernel_x,UP_DIV(cst.input_width-offset_x,cst.dilation_x));\n"
" int sy=max(0,(UP_DIV(-offset_y,cst.dilation_y)));\n"
" int ey=min(cst.kernel_y,UP_DIV(cst.input_height-offset_y,cst.dilation_y));\n"
// Move the input offset to the first visited tap.
" offset_x += sx*cst.dilation_x;\n"
" offset_y += sy*cst.dilation_y;\n"
" auto z_wt=wt+(int)oz*cst.kernel_size;\n"
" auto z_in=in+(int)gid.z*cst.input_size;\n"
" auto z_out=out+(int)gid.z*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x;\n"
// The accumulator starts from the bias and is kept in FLOAT4; it is converted to
// M4 only when the activated result is stored.
" FLOAT4 result=FLOAT4(biasTerms[oz]);\n"
" for (auto ky=sy,y=offset_y; ky<ey; ky++,y += cst.dilation_y) {\n"
" for (auto kx=sx,x=offset_x; kx<ex; kx++,x += cst.dilation_x) {\n"
" auto wt4=z_wt[ky*cst.kernel_x+kx];\n"
" auto in4=z_in[ y*cst.input_width+x];\n"
" result += FLOAT4(in4*wt4);\n"
" }\n"
" }\n"
" *z_out=activate((M4)result,cst.activation);\n"
"}\n"
;
// Metal shader source, stored as a C string, for the activation helper used by
// the convolution shaders in this file.
//   conv_activation_type: int-backed enum with None=0, ReLU=1, ReLU6=2.
//   activate(V, type):    ReLU  -> max(V, 0)
//                         ReLU6 -> clamp(V, 0, 6)
//                         anything else -> V unchanged.
// This definition appears later in the file than the depthwise shader that uses
// it, so the strings must be assembled in a suitable order elsewhere.
// NOTE(review): M4 is not defined in this string; presumably it comes from a
// prelude that is prepended before compilation - confirm.
const char* shader_MetalConvolutionActivation_metal = 
"typedef enum : int {\n"
" None=0,\n"
" ReLU=1,\n"
" ReLU6=2,\n"
"} conv_activation_type;\n"
"inline M4 activate(M4 V,conv_activation_type type) {\n"
" switch (type) {\n"
" case ReLU:\n"
" return max(V,(M4)0);\n"
" case ReLU6:\n"
" return clamp(V,(M4)0,(M4)6);\n"
" default: // None\n"
" return V;\n"
" }\n"
"}\n"
;
// Metal shader source, stored as a C string, for the general convolution kernels.
// Contents of the string:
//   CONV_UNROLL, CONV_MUL_PACK_W2, CONV_NEXT_FLT, CONV_MUL_PACK_H2 helper macros
//   struct conv_constants   parameter block bound at buffer(2)
//   kernel conv             one output element per thread
//   kernel convk3s1d1p1_w2z4  two adjacent output columns and up to four output
//                           slices per thread, with a hard-wired 3x3 tap layout
//   kernel conv_z4          one output position and up to four output slices per thread
// Names used but NOT defined in this string: M4, M4x4, FLOAT4, UP_DIV,
// conv_activation_type and activate().
// NOTE(review): presumably a common prelude plus the activation shader string are
// prepended before compilation - confirm against the shader library builder.
//
// Change from the previous revision: convk3s1d1p1_w2z4 converted the bias with
// float4(...) before adding it to FLOAT4 accumulators, while conv and conv_z4 use
// FLOAT4(...). Metal does not allow implicit conversion between vector types, so
// the addition only compiles when FLOAT4 is float4. All bias conversions now use
// FLOAT4 so the three kernels agree and compile for any FLOAT4 definition.
const char* shader_MetalConvolution_metal = 
"#define CONV_UNROLL (4)\n"
// The macro bodies below are split over several string pieces that carry no
// newline, so after literal concatenation each macro is one preprocessor line.
// CONV_MUL_PACK_W2(x,y): x accumulates the 3x3 products of in[r][0..2] with k[r][0..2],
// y accumulates the same kernel against in[r][1..3], i.e. the next column to the right.
"#define CONV_MUL_PACK_W2(x,y) "
" x += FLOAT4(in00*k00);"
" y += FLOAT4(in01*k00);"
" x += FLOAT4(in01*k01);"
" y += FLOAT4(in02*k01);"
" x += FLOAT4(in02*k02);"
" y += FLOAT4(in03*k02);"
" "
" x += FLOAT4(in10*k10);"
" y += FLOAT4(in11*k10);"
" x += FLOAT4(in11*k11);"
" y += FLOAT4(in12*k11);"
" x += FLOAT4(in12*k12);"
" y += FLOAT4(in13*k12);"
" "
" x += FLOAT4(in20*k20);"
" y += FLOAT4(in21*k20);"
" x += FLOAT4(in21*k21);"
" y += FLOAT4(in22*k21);"
" x += FLOAT4(in22*k22);"
" y += FLOAT4(in23*k22);\n"
" \n"
// CONV_NEXT_FLT: advance z_wt by ws (one output slice worth of filters) and
// reload the nine taps k00..k22. Expects z_wt, ws and k00..k22 in scope.
"#define CONV_NEXT_FLT "
" z_wt += ws; "
" "
" k00=z_wt[0],k01=z_wt[1],k02=z_wt[2];"
" k10=z_wt[3],k11=z_wt[4],k12=z_wt[5];"
" k20=z_wt[6],k21=z_wt[7],k22=z_wt[8];\n"
// CONV_MUL_PACK_H2(x,y): same pattern as W2 but reading input rows 1..3.
// It refers to in30..in33, which no kernel in this string declares, and it is not
// referenced by any kernel in this string.
"#define CONV_MUL_PACK_H2(x,y) "
" x += FLOAT4(in10*k00);"
" y += FLOAT4(in11*k00);"
" x += FLOAT4(in11*k01);"
" y += FLOAT4(in12*k01);"
" x += FLOAT4(in12*k02);"
" y += FLOAT4(in13*k02);"
" "
" x += FLOAT4(in20*k10);"
" y += FLOAT4(in21*k10);"
" x += FLOAT4(in21*k11);"
" y += FLOAT4(in22*k11);"
" x += FLOAT4(in22*k12);"
" y += FLOAT4(in23*k12);"
" "
" x += FLOAT4(in30*k20);"
" y += FLOAT4(in31*k20);"
" x += FLOAT4(in31*k21);"
" y += FLOAT4(in32*k21);"
" x += FLOAT4(in32*k22);"
" y += FLOAT4(in33*k22);\n"
// Parameter block bound at buffer(2); the field order is part of the host/shader ABI.
"struct conv_constants {\n"
" int input_width;\n"
" int input_height;\n"
" int input_size;\n"
" int input_slice;\n"
" int output_width;\n"
" int output_height;\n"
" int output_size;\n"
" int output_slice;\n"
" int batch;\n"
" int oz_size;\n"
" int threadgroup_input_slice;\n"
" \n"
" int kernel_x;\n"
" int kernel_y;\n"
" int kernel_size;\n"
" int stride_x;\n"
" int stride_y;\n"
" int pad_x;\n"
" int pad_y;\n"
" int dilation_x;\n"
" int dilation_y;\n"
" conv_activation_type activation; \n"
"};\n"
// conv: one thread per output element. gid.z is split into an output slice
// (gid.z / batch) and a batch index (gid.z % batch); threads with
// gid.z >= oz_size return early.
"kernel void conv(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.oz_size) return;\n"
" \n"
" int idx_w=gid.x;\n"
" int idx_h=gid.y;\n"
" int idx_c=gid.z/cst.batch;\n"
" int idx_b=gid.z % cst.batch;\n"
" \n"
" int offset_x=(int)idx_w*cst.stride_x-cst.pad_x;\n"
" int offset_y=(int)idx_h*cst.stride_y-cst.pad_y;\n"
// [sx,ex) and [sy,ey) restrict the visited kernel taps; kw and kh are their counts.
// NOTE(review): presumably UP_DIV is a ceiling division so that only taps inside
// the input plane are visited - confirm where UP_DIV is defined.
" int sx=max(0,(UP_DIV(-offset_x,cst.dilation_x)));\n"
" int ex=min(cst.kernel_x,UP_DIV(cst.input_width-offset_x,cst.dilation_x));\n"
" int kw=ex-sx;\n"
" int sy=max(0,(UP_DIV(-offset_y,cst.dilation_y)));\n"
" int ey=min(cst.kernel_y,UP_DIV(cst.input_height-offset_y,cst.dilation_y));\n"
" int kh=ey-sy;\n"
" offset_x += sx*cst.dilation_x;\n"
" offset_y += sy*cst.dilation_y;\n"
" \n"
" auto z_in=in+idx_b*cst.input_slice*cst.input_size+offset_y*cst.input_width+offset_x;\n"
" auto z_wt=wt+idx_c*cst.input_slice*cst.kernel_size+sy*cst.kernel_x+sx;\n"
" auto z_out=out+idx_b*cst.output_slice*cst.output_size+(int)idx_c*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x;\n"
" int dilation_h=cst.input_width*cst.dilation_y;\n"
// The accumulator starts from the bias and stays in FLOAT4 until the final store.
" FLOAT4 result=FLOAT4(biasTerms[idx_c]);\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" for (auto y=0; y<kh; y++) {\n"
" for (auto x=0; x<kw; x++) {\n"
" auto wt4=z_wt[z*cst.kernel_size+y*cst.kernel_x+x];\n"
" auto in4=z_in[z*cst.input_size+y*dilation_h+x*cst.dilation_x];\n"
" result += FLOAT4(in4*wt4);\n"
" }\n"
" }\n"
" }\n"
" *z_out=activate(M4(result),cst.activation);\n"
"}\n"
// convk3s1d1p1_w2z4: each thread produces output columns 2*gid.x and 2*gid.x+1 for
// up to CONV_UNROLL consecutive output slices. The code indexes exactly nine
// weights per input slice and steps the input by one element per tap, so it only
// matches a 3x3 kernel with unit stride and dilation; the host must select it
// accordingly. There is no early return on gid.z in this kernel.
"kernel void convk3s1d1p1_w2z4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*2 >= cst.output_width || (int)gid.y >= cst.output_height) return;\n"
" \n"
" int idx_w=gid.x << 1;\n"
" int idx_h=gid.y;\n"
" int idx_c=gid.z/cst.batch;\n"
" int idx_b=gid.z % cst.batch;\n"
// uz holds the four output slice indices; valids flags which of slices 1..3 exist.
// Slice uz[0] is always processed.
" int4 uz=idx_c*CONV_UNROLL+int4(0,1,2,3);\n"
" bool3 valids=uz.yzw<cst.output_slice;\n"
// valid_x tells whether the second output column is inside the output row.
" bool valid_x=(int)(gid.x*2+1)<cst.output_width;\n"
" int offset_x=(int)gid.x*2-cst.pad_x;\n"
" int offset_y=(int)gid.y-cst.pad_y;\n"
" auto z_in=in+idx_b*cst.input_slice*cst.input_size+offset_y*cst.input_width+offset_x;\n"
" auto z_flt=wt+uz[0]*cst.input_slice*cst.kernel_size;\n"
" auto z_out=out+idx_b*cst.output_slice*cst.output_size+uz[0]*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" \n"
" int ws=cst.input_slice*cst.kernel_size;\n"
// result0..3 belong to the first output column, result4..7 to the second.
" FLOAT4 result0=0,result1=0,result2=0,result3=0;\n"
" FLOAT4 result4=0,result5=0,result6=0,result7=0;\n"
" for (auto z=0; z<cst.input_slice; z++,z_flt += cst.kernel_size,z_in += cst.input_size) {\n"
// Load a 3x4 input window; positions that fail the edge tests read as zero.
// Only the low edge is tested for row 0 / column 0 and only the high edge for the
// remaining rows and columns.
" auto in00=(offset_x<0 || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+0);\n"
" auto in01=(offset_x+1>=cst.input_width || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+1);\n"
" auto in02=(offset_x+2>=cst.input_width || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+2);\n"
" auto in03=(offset_x+3>=cst.input_width || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+3);\n"
" auto in10=(offset_x<0 || offset_y+1>=cst.input_height) ? (M4)0.f : *(z_in+1*cst.input_width+0);\n"
" auto in11=(offset_x+1>=cst.input_width || offset_y+1>=cst.input_height) ? (M4)0.f : *(z_in+1*cst.input_width+1);\n"
" auto in12=(offset_x+2>=cst.input_width || offset_y+1>=cst.input_height) ? (M4)0.f : *(z_in+1*cst.input_width+2);\n"
" auto in13=(offset_x+3>=cst.input_width || offset_y+1>=cst.input_height) ? (M4)0.f : *(z_in+1*cst.input_width+3);\n"
" \n"
" auto in20=(offset_x<0 || offset_y+2>=cst.input_height) ? (M4)0.f : *(z_in+2*cst.input_width+0);\n"
" auto in21=(offset_x+1>=cst.input_width || offset_y+2>=cst.input_height) ? (M4)0.f : *(z_in+2*cst.input_width+1);\n"
" auto in22=(offset_x+2>=cst.input_width || offset_y+2>=cst.input_height) ? (M4)0.f : *(z_in+2*cst.input_width+2);\n"
" auto in23=(offset_x+3>=cst.input_width || offset_y+2>=cst.input_height) ? (M4)0.f : *(z_in+2*cst.input_width+3);\n"
" \n"
" auto z_wt=z_flt;\n"
" auto k00=z_wt[0],k01=z_wt[1],k02=z_wt[2];\n"
" auto k10=z_wt[3],k11=z_wt[4],k12=z_wt[5];\n"
" auto k20=z_wt[6],k21=z_wt[7],k22=z_wt[8];\n"
" CONV_MUL_PACK_W2(result0,result4);\n"
// Each following step moves on to the filters of the next output slice.
" if (valids[0]) {\n"
" CONV_NEXT_FLT;\n"
" CONV_MUL_PACK_W2(result1,result5);\n"
" }\n"
" if (valids[1]) {\n"
" CONV_NEXT_FLT;\n"
" CONV_MUL_PACK_W2(result2,result6);\n"
" }\n"
" if (valids[2]) {\n"
" CONV_NEXT_FLT;\n"
" CONV_MUL_PACK_W2(result3,result7);\n"
" }\n"
" }\n"
// Bias is added in the accumulator type (FLOAT4) and the sum is converted to M4
// for the activation and store, matching conv_z4.
" /* true */ *z_out=activate(M4(result0+FLOAT4(biasTerms[uz[0]])),cst.activation);\n"
" if(valid_x) {\n"
" *(z_out+1)=activate(M4(result4+FLOAT4(biasTerms[uz[0]])),cst.activation);\n"
" }\n"
" if (valids[0]) {\n"
" z_out += cst.output_size;\n"
" *z_out=activate(M4(result1+FLOAT4(biasTerms[uz[1]])),cst.activation);\n"
" if(valid_x) {\n"
" *(z_out+1)=activate(M4(result5+FLOAT4(biasTerms[uz[1]])),cst.activation);\n"
" }\n"
" }\n"
" if (valids[1]) {\n"
" z_out += cst.output_size;\n"
" *z_out=activate(M4(result2+FLOAT4(biasTerms[uz[2]])),cst.activation);\n"
" if(valid_x) {\n"
" *(z_out+1)=activate(M4(result6+FLOAT4(biasTerms[uz[2]])),cst.activation);\n"
" }\n"
" }\n"
" if (valids[2]) {\n"
" z_out += cst.output_size;\n"
" *z_out=activate(M4(result3+FLOAT4(biasTerms[uz[3]])),cst.activation);\n"
" if(valid_x) {\n"
" *(z_out+1)=activate(M4(result7+FLOAT4(biasTerms[uz[3]])),cst.activation);\n"
" }\n"
" }\n"
"}\n"
// conv_z4: general stride / dilation / kernel size like conv, but each thread
// produces up to CONV_UNROLL consecutive output slices for one output position,
// reusing every loaded input value for all of them. There is no early return on
// gid.z in this kernel.
"kernel void conv_z4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height) return;\n"
" \n"
" int idx_w=gid.x;\n"
" int idx_h=gid.y;\n"
" int idx_c=gid.z/cst.batch;\n"
" int idx_b=gid.z % cst.batch;\n"
" \n"
" int4 uz=idx_c*CONV_UNROLL+int4(0,1,2,3);\n"
" bool3 valids=uz.yzw<cst.output_slice;\n"
" \n"
" int offset_x=idx_w*cst.stride_x-cst.pad_x;\n"
" int offset_y=idx_h*cst.stride_y-cst.pad_y;\n"
" int sx=max(0,(UP_DIV(-offset_x,cst.dilation_x)));\n"
" int ex=min(cst.kernel_x,UP_DIV(cst.input_width-offset_x,cst.dilation_x));\n"
" int kw=ex-sx;\n"
" int sy=max(0,(UP_DIV(-offset_y,cst.dilation_y)));\n"
" int ey=min(cst.kernel_y,UP_DIV(cst.input_height-offset_y,cst.dilation_y));\n"
" int kh=ey-sy;\n"
" offset_x += sx*cst.dilation_x;\n"
" offset_y += sy*cst.dilation_y;\n"
" \n"
" auto z_in=in+idx_b*cst.input_slice*cst.input_size+offset_y*cst.input_width+offset_x;\n"
" auto z_wt=wt+uz[0]*cst.input_slice*cst.kernel_size+sy*cst.kernel_x+sx;\n"
" auto z_out=out+idx_b*cst.output_slice*cst.output_size+uz[0]*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" \n"
// ws is the distance, in weight elements, between the filters of two output slices.
" int ws=cst.input_slice*cst.kernel_size;\n"
" int dilation_h=cst.input_width*cst.dilation_y;\n"
" FLOAT4 result0=0,result1=0,result2=0,result3=0;\n"
" for (auto z=0; z<cst.input_slice; z++,z_wt += cst.kernel_size,z_in += cst.input_size) {\n"
" for (auto y=0; y<kh; y++) {\n"
" for (auto x=0; x<kw; x++) {\n"
" auto x_wt=z_wt+y*cst.kernel_x+x;\n"
" auto in4=z_in[ y*dilation_h+x*cst.dilation_x];\n"
" /* true */ result0 += FLOAT4(in4**x_wt);\n"
" if (valids[0]) { x_wt += ws; result1 += FLOAT4(in4**x_wt); }\n"
" if (valids[1]) { x_wt += ws; result2 += FLOAT4(in4**x_wt); }\n"
" if (valids[2]) { x_wt += ws; result3 += FLOAT4(in4**x_wt); }\n"
" }\n"
" }\n"
" }\n"
" /* true */ *z_out=activate(M4(result0+FLOAT4(biasTerms[uz[0]])),cst.activation);\n"
" if (valids[0]) { z_out += cst.output_size; *z_out=activate(M4(result1+FLOAT4(biasTerms[uz[1]])),cst.activation); }\n"
" if (valids[1]) { z_out += cst.output_size; *z_out=activate(M4(result2+FLOAT4(biasTerms[uz[2]])),cst.activation); }\n"
" if (valids[2]) { z_out += cst.output_size; *z_out=activate(M4(result3+FLOAT4(biasTerms[uz[3]])),cst.activation); }\n"
"}\n"
;
// Metal shader source, stored as a C string, for grid sampling.
// Contents of the string:
//   struct grid_sample_params  parameter block bound at buffer(3)
//   getPosition()  maps a normalised grid coordinate to an input pixel coordinate
//   CLAMP()        integer clamp to [min, max]
//   sample()       fetches one M4 texel, handling out-of-range coordinates
//   interpolate()  nearest or bilinear lookup built on sample()
//   kernel grid_sample  one thread per (x, y, batch) output location
// NOTE(review): M and M4 are not defined in this string; presumably they come from
// a prelude that is prepended before compilation - confirm.
const char* shader_MetalGridSample_metal = 
"struct grid_sample_params {\n"
" int batches;\n"
" int channels;\n"
" int inH;\n"
" int inW;\n"
" int outH;\n"
" int outW;\n"
" int mode; // 0-Bilinear,1-Nearest\n"
" int paddingMode; // 0-Zeros,1-Border,2-Reflection\n"
" int alignCorners;\n"
"};\n"
// getPosition: in reflection mode (paddingMode == 2) x is first folded back into
// [-1, 1]; then the result is ((1+x)*(range-a)-b)/2 with (a,b) = (1,0) when
// alignCorners is non-zero and (0,1) otherwise.
"static float getPosition(float x,int range,int alignCorners,int paddingMode) {\n"
" if (paddingMode == 2/*GridSamplePaddingMode_REFLECTION*/) {\n"
" // if x is on the left side of -1.0,move it to the right side of 1.0\n"
" if (x<-1.0f) {\n"
" x=x+::ceil(1-x)*4;\n"
" }\n"
" // reflect\n"
" if (x>1.0f) {\n"
" float l=x-1.0f;\n"
" int reflectionNum=::floor(l/2.0);\n"
" float offset=l-reflectionNum*2.0f;\n"
" x=(reflectionNum % 2 == 0) ? (1-offset) : (-1.0f+offset);\n"
" }\n"
" }\n"
" float a=alignCorners ? 1.0f : 0.0f;\n"
" float b=alignCorners ? 0.0f : 1.0f;\n"
" return ((1+x)*(range-a)-b)/2.0f;\n"
"}\n"
"static int CLAMP(int v,int min,int max) {\n"
" if ((v)<min) {\n"
" (v)=min;\n"
" } else if ((v)>max) {\n"
" (v)=max;\n"
" }\n"
" return v;\n"
"}\n"
// sample: out-of-range coordinates yield zero in padding mode 0; in every other
// mode the coordinates are clamped to the nearest valid texel.
"static M4 sample(int h,int w,const device M4 *buffer,int height,int width,int paddingMode) {\n"
" if (h<0 || h >= height || w<0 || w >= width) {\n"
" if (paddingMode == 0/*GridSamplePaddingMode_ZEROS*/) {\n"
" return 0.0f;\n"
" }\n"
" // Clearly,CLAMP is the right way to go for GridSamplePaddingMode_BORDER\n"
" // For GridSamplePaddingMode_REFLECTION,since we have reflected the Vs into (-1,1),\n"
" // the leftover reflections degrade to GridSamplePaddingMode_BORDER\n"
" h=CLAMP(h,0,height-1);\n"
" w=CLAMP(w,0,width-1);\n"
" }\n"
" return buffer[h*width+w];\n"
"}\n"
// interpolate: mode 1 rounds to the nearest texel via floor(v + 0.5); any other
// mode blends the four neighbouring texels with weights derived from the
// fractional parts of h and w.
"static M4 interpolate(float h,float w,const device M4 *buffer,int height,int width,int mode,\n"
" int paddingMode) {\n"
" if (mode == 1/*GridSampleMode_NEAREST*/) {\n"
" int nh=::floor(h+0.5f);\n"
" int nw=::floor(w+0.5f);\n"
" return sample(nh,nw,buffer,height,width,paddingMode);\n"
" }\n"
" // mode == GridSampleMode_BILINEAR\n"
" int w0_h=::floor(h);\n"
" int w0_w=::floor(w);\n"
" int w1_h=w0_h+1;\n"
" int w1_w=w0_w+1;\n"
" M4 oneV=(M4)((M)1.0f);\n"
" M4 i00=sample(w0_h,w0_w,buffer,height,width,paddingMode);\n"
" M4 i01=sample(w0_h,w1_w,buffer,height,width,paddingMode);\n"
" M4 i10=sample(w1_h,w0_w,buffer,height,width,paddingMode);\n"
" M4 i11=sample(w1_h,w1_w,buffer,height,width,paddingMode);\n"
" \n"
" M4 f0=(M4)((M)(w1_w-w));\n"
" M4 f1=oneV-f0;\n"
" M4 h0=(M4)((M)(w1_h-h));\n"
" M4 h1=oneV-h0;\n"
" M4 i0=i00*f0+i01*f1;\n"
" M4 i1=i10*f0+i11*f1;\n"
" return i0*h0+i1*h1;\n"
"}\n"
// grid_sample: the grid buffer stores two scalars per output location (index +0 is
// passed as the x coordinate, +1 as y). The sampled position is computed once and
// reused for every group of four channels, (channels+3)/4 groups in total.
"kernel void grid_sample(const device M4 *input [[buffer(0)]],\n"
" const device M *grid [[buffer(1)]],\n"
" device M4 *output [[buffer(2)]],\n"
" constant grid_sample_params &p [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= p.outW || (int)gid.y >= p.outH || (int)gid.z >= p.batches)\n"
" return;\n"
" int gridPos=gid.z*p.outH*p.outW*2+gid.y*p.outW*2+gid.x*2;\n"
" auto x=getPosition(grid[gridPos+0],p.inW,p.alignCorners,p.paddingMode);\n"
" auto y=getPosition(grid[gridPos+1],p.inH,p.alignCorners,p.paddingMode);\n"
" \n"
" const int channelC4=(p.channels+3)/4;\n"
" for (int c=0; c<channelC4; ++ c) {\n"
" auto outputPos=gid.z*channelC4*p.outH*p.outW+c*p.outH*p.outW+gid.y*p.outW+gid.x;\n"
" auto inputPtr=input+gid.z*channelC4*p.inH*p.inW+c*p.inH*p.inW;\n"
" output[outputPos]=interpolate(y,x,inputPtr,p.inH,p.inW,p.mode,p.paddingMode);\n"
" }\n"
"}\n"
;
// Metal shader source, stored as a C string, for single-axis reductions.
// Contents of the string:
//   struct reduce_shape   parameter block bound at buffer(2)
//   reduce_mean / reduce_sum / reduce_min / reduce_max / reduce_prod templates
//   define_reduce(name)   macro that emits two kernels per reduction:
//                         reduce_<name>_f on M data and reduce_<name>_s on int data
// In every helper gid.x selects the outer index and gid.y the inner index; the
// input is walked along the axis with a step of inside_size elements and the
// result is written to out[gid.x*inside_size + gid.y].
// Inside the templates the type parameter named M is the accumulator type and
// hides the file-level M; the _f kernels pass FLOAT for it.
// NOTE(review): M and FLOAT are not defined in this string; presumably they come
// from a prelude that is prepended before compilation - confirm.
const char* shader_MetalReduction_metal = 
"struct reduce_shape {\n"
" int outside_size;\n"
" int axis_size;\n"
" int inside_size;\n"
" int outside_step;\n"
"};\n"
// Mean: accumulate in M, divide by axis_size, convert back to T.
"template <typename M,typename T>\n"
"static inline void reduce_mean(const device T *in,device T *out,constant reduce_shape &s,uint2 gid) {\n"
" auto axis_in=in+gid.x*s.outside_step+gid.y;\n"
" M summer=0;\n"
" for (int i=0; i<s.axis_size; i++,axis_in += s.inside_size) {\n"
" summer += M(*axis_in);\n"
" }\n"
" out[int(gid.x)*s.inside_size+int(gid.y)]=T(summer/s.axis_size);\n"
"}\n"
"template <typename M,typename T>\n"
"static inline void reduce_sum(const device T *in,device T *out,constant reduce_shape &s,uint2 gid) {\n"
" auto axis_in=in+gid.x*s.outside_step+gid.y;\n"
" M summer=0;\n"
" for (int i=0; i<s.axis_size; i++,axis_in += s.inside_size) {\n"
" summer += M(*axis_in);\n"
" }\n"
" out[int(gid.x)*s.inside_size+int(gid.y)]=T(summer);\n"
"}\n"
// Min and max seed the running value from the first element before the loop, so
// the first element is read even when axis_size is zero. They work directly in T.
"template <typename M,typename T>\n"
"static inline void reduce_min(const device T *in,device T *out,constant reduce_shape &s,uint2 gid) {\n"
" auto axis_in=in+gid.x*s.outside_step+gid.y;\n"
" T summer=*axis_in; axis_in += s.inside_size;\n"
" for (int i=1; i<s.axis_size; i++,axis_in += s.inside_size) {\n"
" summer=min(summer,*axis_in);\n"
" }\n"
" out[int(gid.x)*s.inside_size+int(gid.y)]=summer;\n"
"}\n"
"template <typename M,typename T>\n"
"static inline void reduce_max(const device T *in,device T *out,constant reduce_shape &s,uint2 gid) {\n"
" auto axis_in=in+gid.x*s.outside_step+gid.y;\n"
" T summer=*axis_in; axis_in += s.inside_size;\n"
" for (int i=1; i<s.axis_size; i++,axis_in += s.inside_size) {\n"
" summer=max(summer,*axis_in);\n"
" }\n"
" out[int(gid.x)*s.inside_size+int(gid.y)]=summer;\n"
"}\n"
// Product: accumulate in M starting from 1.
"template <typename M,typename T>\n"
"static inline void reduce_prod(const device T *in,device T *out,constant reduce_shape &s,uint2 gid) {\n"
" auto axis_in=in+gid.x*s.outside_step+gid.y;\n"
" M summer=1;\n"
" for (int i=0; i<s.axis_size; i++,axis_in += s.inside_size) {\n"
" summer *= M(*axis_in);\n"
" }\n"
" out[int(gid.x)*s.inside_size+int(gid.y)]=T(summer);\n"
"}\n"
// The macro body is spread over string pieces without newlines, so after literal
// concatenation it forms one preprocessor line. Both generated kernels skip
// threads outside outside_size x inside_size.
"#define define_reduce(name) "
"kernel void reduce_##name##_f(const device M *in [[buffer(0)]],"
" device M *out [[buffer(1)]],"
" constant reduce_shape &s [[buffer(2)]],"
" uint2 gid [[thread_position_in_grid]]) { "
" if (gid.x<(uint)s.outside_size && gid.y<(uint)s.inside_size) reduce_##name<FLOAT,M>(in,out,s,gid); "
"} "
"kernel void reduce_##name##_s(const device int *in [[buffer(0)]],"
" device int *out [[buffer(1)]],"
" constant reduce_shape &s [[buffer(2)]],"
" uint2 gid [[thread_position_in_grid]]) { "
" if (gid.x<(uint)s.outside_size && gid.y<(uint)s.inside_size) reduce_##name<int,int>(in,out,s,gid); "
"}\n"
"define_reduce(mean);\n"
"define_reduce(sum);\n"
"define_reduce(min);\n"
"define_reduce(max);\n"
"define_reduce(prod);\n"
;
// Metal shader source, stored as a C string, for backend utility kernels:
//   - version_func_002: empty kernel
//   - copy_* / upcast_* / downcast_*: element-wise copies and precision conversion
//   - NHWC / NCHW <-> NC4HW4 layout conversion (templates plus typed kernels)
//   - blit_*: strided 3-D copies for several element widths
//   - NHWC <-> NCHW conversion with precision change
// In the layout kernels gid.x indexes a position inside a plane of s.size elements
// and gid.y indexes batch*slice, where each slice groups four channels.
// NOTE(review): M and M4 are not defined in this string; presumably they come from
// a prelude that is prepended before compilation - confirm.
const char* shader_MetalBackend_metal = 
"struct tensor_shape {\n"
" int size;\n"
" int channel;\n"
" int slice;\n"
" int batch_slices;\n"
"};\n"
"kernel void version_func_002(const device uchar *in [[buffer(0)]],\n"
" device uchar *out [[buffer(1)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" // do nothing,just for verifying match between mnn and metallib\n"
"}\n"
// The five kernels below copy or convert one element per thread and do not
// bounds-check gid.
"kernel void copy_byte(const device uchar *in [[buffer(0)]],\n"
" device uchar *out [[buffer(1)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=in[int(gid)];\n"
"}\n"
"kernel void copy_int(const device int *in [[buffer(0)]],\n"
" device int *out [[buffer(1)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=in[int(gid)];\n"
"}\n"
"kernel void copy_float(const device M *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=in[int(gid)];\n"
"}\n"
"kernel void upcast_float(const device M *in [[buffer(0)]],\n"
" device float *out [[buffer(1)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=in[int(gid)];\n"
"}\n"
"kernel void downcast_float(const device float *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=in[int(gid)];\n"
"}\n"
// Limit carries the element count in size.x; the vector conversions below skip
// threads at or beyond it.
"struct Limit {\n"
" uint4 size;\n"
"};\n"
"kernel void upcast_float4(const device M4 *in [[buffer(0)]],\n"
" device float4 *out [[buffer(1)]],\n"
" constant Limit& limit [[buffer(2)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" if (gid<limit.size.x) {\n"
" out[int(gid)]=float4(in[int(gid)]);\n"
" }\n"
"}\n"
"kernel void downcast_float4(const device float4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant Limit& limit [[buffer(2)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" if (gid<limit.size.x) {\n"
" out[int(gid)]=M4(in[int(gid)]);\n"
" }\n"
"}\n"
// NHWC -> NC4HW4: gather four consecutive channels of one pixel into one vector;
// channels at or beyond s.channel are filled with zero.
"template <typename IType,typename OType>\n"
"static inline void template_NHWC_to_NC4HW4(const device IType *in,device OType *out,constant tensor_shape &s,uint2 gid) {\n"
" int b=gid.y/s.slice;\n"
" int z=gid.y % s.slice;\n"
" int c=z*4;\n"
" \n"
" auto off_in=in+b*s.size*s.channel+int(gid.x)*s.channel+c;\n"
" auto off_out=out+int(gid.y)*s.size+int(gid.x);\n"
" off_out[0]=OType(c+0<s.channel ? off_in[0] : 0,\n"
" c+1<s.channel ? off_in[1] : 0,\n"
" c+2<s.channel ? off_in[2] : 0,\n"
" c+3<s.channel ? off_in[3] : 0);\n"
"}\n"
"kernel void upcast_f_NHWC_to_NC4HW4(const device M *in [[buffer(0)]],\n"
" device float4 *out [[buffer(1)]],\n"
" constant tensor_shape &s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.size && (int)gid.y<s.batch_slices) template_NHWC_to_NC4HW4<M,float4>(in,out,s,gid);\n"
"}\n"
"kernel void downcast_f_NHWC_to_NC4HW4(const device float *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant tensor_shape &s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.size && (int)gid.y<s.batch_slices) template_NHWC_to_NC4HW4<float,M4>(in,out,s,gid);\n"
"}\n"
"kernel void cvt_u_NHWC_to_NC4HW4(const device uchar *in [[buffer(0)]],\n"
" device uchar4 *out [[buffer(1)]],\n"
" constant tensor_shape &s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.size && (int)gid.y<s.batch_slices) template_NHWC_to_NC4HW4<uchar,uchar4>(in,out,s,gid);\n"
"}\n"
"kernel void cvt_f_NHWC_to_NC4HW4(const device M *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant tensor_shape &s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.size && (int)gid.y<s.batch_slices) template_NHWC_to_NC4HW4<M,M4>(in,out,s,gid);\n"
"}\n"
// NC4HW4 -> NHWC: scatter one vector back to consecutive channels. Lane 0 is
// always written; lanes 1..3 only when their channel index is below s.channel.
"template <typename IType,typename OType>\n"
"static inline void template_NC4HW4_to_NHWC(const device IType *in,device OType *out,constant tensor_shape &s,uint2 gid) {\n"
" int b=gid.y/s.slice;\n"
" int z=gid.y % s.slice;\n"
" int c=z*4;\n"
" auto off_in=in+int(gid.y)*s.size+int(gid.x);\n"
" auto off_out=out+b*s.size*s.channel+int(gid.x)*s.channel+c;\n"
" \n"
" IType v4=off_in[0];\n"
" /* if (1) */ off_out[0]=v4[0];\n"
" if (c+1<s.channel) off_out[1]=v4[1];\n"
" if (c+2<s.channel) off_out[2]=v4[2];\n"
" if (c+3<s.channel) off_out[3]=v4[3];\n"
"}\n"
"kernel void upcast_f_NC4HW4_to_NHWC(const device M4 *in [[buffer(0)]],\n"
" device float *out [[buffer(1)]],\n"
" constant tensor_shape &s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.size && (int)gid.y<s.batch_slices) template_NC4HW4_to_NHWC<M4,float>(in,out,s,gid);\n"
"}\n"
"kernel void downcast_f_NC4HW4_to_NHWC(const device float4 *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" constant tensor_shape &s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.size && (int)gid.y<s.batch_slices) template_NC4HW4_to_NHWC<float4,M>(in,out,s,gid);\n"
"}\n"
"kernel void cvt_u_NC4HW4_to_NHWC(const device uchar4 *in [[buffer(0)]],\n"
" device uchar *out [[buffer(1)]],\n"
" constant tensor_shape &s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.size && (int)gid.y<s.batch_slices) template_NC4HW4_to_NHWC<uchar4,uchar>(in,out,s,gid);\n"
"}\n"
"kernel void cvt_f_NC4HW4_to_NHWC(const device M4 *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" constant tensor_shape &s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.size && (int)gid.y<s.batch_slices) template_NC4HW4_to_NHWC<M4,M>(in,out,s,gid);\n"
"}\n"
// NCHW -> NC4HW4: the four source channels are s.size elements apart. The zero
// fill here is written as the half literal 0.0h, unlike the NHWC variant above.
"template <typename IType,typename OType>\n"
"static inline void template_NCHW_to_NC4HW4(const device IType *in,device OType *out,constant tensor_shape &s,uint2 gid) {\n"
" int b=gid.y/s.slice;\n"
" int z=gid.y % s.slice;\n"
" int c=z*4;\n"
" \n"
" auto off_in=in+(b*s.channel+c)*s.size+int(gid.x);\n"
" auto off_out=out+int(gid.y)*s.size+int(gid.x);\n"
" off_out[0]=OType(c+0<s.channel ? off_in[0*s.size] : 0.0h,\n"
" c+1<s.channel ? off_in[1*s.size] : 0.0h,\n"
" c+2<s.channel ? off_in[2*s.size] : 0.0h,\n"
" c+3<s.channel ? off_in[3*s.size] : 0.0h);\n"
"}\n"
"kernel void upcast_f_NCHW_to_NC4HW4(const device M *in [[buffer(0)]],\n"
" device float4 *out [[buffer(1)]],\n"
" constant tensor_shape &s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.size && (int)gid.y<s.batch_slices) template_NCHW_to_NC4HW4<M,float4>(in,out,s,gid);\n"
"}\n"
"kernel void downcast_f_NCHW_to_NC4HW4(const device float *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant tensor_shape &s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.size && (int)gid.y<s.batch_slices) template_NCHW_to_NC4HW4<float,M4>(in,out,s,gid);\n"
"}\n"
"kernel void cvt_u_NCHW_to_NC4HW4(const device uchar *in [[buffer(0)]],\n"
" device uchar4 *out [[buffer(1)]],\n"
" constant tensor_shape &s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.size && (int)gid.y<s.batch_slices) template_NCHW_to_NC4HW4<uchar,uchar4>(in,out,s,gid);\n"
"}\n"
"kernel void cvt_f_NCHW_to_NC4HW4(const device M *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant tensor_shape &s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.size && (int)gid.y<s.batch_slices) template_NCHW_to_NC4HW4<M,M4>(in,out,s,gid);\n"
"}\n"
// NC4HW4 -> NCHW: inverse of the previous template; lane 0 is always written and
// lanes 1..3 only when their channel index is below s.channel.
"template <typename IType,typename OType>\n"
"static inline void template_NC4HW4_to_NCHW(const device IType *in,device OType *out,constant tensor_shape &s,uint2 gid) {\n"
" int b=gid.y/s.slice;\n"
" int z=gid.y % s.slice;\n"
" int c=z*4;\n"
" \n"
" auto off_in=in+int(gid.y)*s.size+int(gid.x);\n"
" auto off_out=out+(b*s.channel+c)*s.size+int(gid.x);\n"
" IType v4=off_in[0];\n"
" /* if (1) */ off_out[0*s.size]=v4.x;\n"
" if (c+1<s.channel) off_out[1*s.size]=v4.y;\n"
" if (c+2<s.channel) off_out[2*s.size]=v4.z;\n"
" if (c+3<s.channel) off_out[3*s.size]=v4.w;\n"
"}\n"
"kernel void upcast_f_NC4HW4_to_NCHW(const device M4 *in [[buffer(0)]],\n"
" device float *out [[buffer(1)]],\n"
" constant tensor_shape &s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.size && (int)gid.y<s.batch_slices) template_NC4HW4_to_NCHW<M4,float>(in,out,s,gid);\n"
"}\n"
"kernel void downcast_f_NC4HW4_to_NCHW(const device float4 *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" constant tensor_shape &s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.size && (int)gid.y<s.batch_slices) template_NC4HW4_to_NCHW<float4,M>(in,out,s,gid);\n"
"}\n"
"kernel void cvt_u_NC4HW4_to_NCHW(const device uchar4 *in [[buffer(0)]],\n"
" device uchar *out [[buffer(1)]],\n"
" constant tensor_shape &s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.size && (int)gid.y<s.batch_slices) template_NC4HW4_to_NCHW<uchar4,uchar>(in,out,s,gid);\n"
"}\n"
"kernel void cvt_f_NC4HW4_to_NCHW(const device M4 *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" constant tensor_shape &s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.size && (int)gid.y<s.batch_slices) template_NC4HW4_to_NCHW<M4,M>(in,out,s,gid);\n"
"}\n"
// Strided blits: for each gid inside info.size.xyz one element is copied from
// in[dot(gid, stride.xyz) + stride.w] to out[dot(gid, extent.xyz) + extent.w].
// The kernels differ only in element type: int4, int, char, short and short4.
"struct SamplerInfo {\n"
" uint4 stride;//stride[3]+offset\n"
" uint4 size;//size[3]+totalSize\n"
" uint4 extent;//dstStride[3]+dstOffset\n"
" uint4 imageSize;\n"
"};\n"
"kernel void blit_intx4(const device int4 *in [[buffer(0)]],\n"
" device int4 *out [[buffer(1)]],\n"
" constant SamplerInfo &info [[buffer(2)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if (gid.x<info.size.x && gid.y<info.size.y && gid.z<info.size.z) {\n"
" uint dstOffset=gid.x*info.extent.x+gid.y*info.extent.y+gid.z*info.extent.z+info.extent.w;\n"
" uint srcOffset=gid.x*info.stride.x+gid.y*info.stride.y+gid.z*info.stride.z+info.stride.w;\n"
" out[int(dstOffset)]=in[int(srcOffset)];\n"
" }\n"
"}\n"
"kernel void blit_int(const device int *in [[buffer(0)]],\n"
" device int *out [[buffer(1)]],\n"
" constant SamplerInfo &info [[buffer(2)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if (gid.x<info.size.x && gid.y<info.size.y && gid.z<info.size.z) {\n"
" uint dstOffset=gid.x*info.extent.x+gid.y*info.extent.y+gid.z*info.extent.z+info.extent.w;\n"
" uint srcOffset=gid.x*info.stride.x+gid.y*info.stride.y+gid.z*info.stride.z+info.stride.w;\n"
" out[int(dstOffset)]=in[int(srcOffset)];\n"
" }\n"
"}\n"
"kernel void blit_int8(const device char *in [[buffer(0)]],\n"
" device char *out [[buffer(1)]],\n"
" constant SamplerInfo &info [[buffer(2)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if (gid.x<info.size.x && gid.y<info.size.y && gid.z<info.size.z) {\n"
" uint dstOffset=gid.x*info.extent.x+gid.y*info.extent.y+gid.z*info.extent.z+info.extent.w;\n"
" uint srcOffset=gid.x*info.stride.x+gid.y*info.stride.y+gid.z*info.stride.z+info.stride.w;\n"
" out[int(dstOffset)]=in[int(srcOffset)];\n"
" }\n"
"}\n"
"kernel void blit_int16(const device short *in [[buffer(0)]],\n"
" device short *out [[buffer(1)]],\n"
" constant SamplerInfo &info [[buffer(2)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if (gid.x<info.size.x && gid.y<info.size.y && gid.z<info.size.z) {\n"
" uint dstOffset=gid.x*info.extent.x+gid.y*info.extent.y+gid.z*info.extent.z+info.extent.w;\n"
" uint srcOffset=gid.x*info.stride.x+gid.y*info.stride.y+gid.z*info.stride.z+info.stride.w;\n"
" out[int(dstOffset)]=in[int(srcOffset)];\n"
" }\n"
"}\n"
// blit_int64 moves each element as a short4 vector.
"kernel void blit_int64(const device short4 *in [[buffer(0)]],\n"
" device short4 *out [[buffer(1)]],\n"
" constant SamplerInfo &info [[buffer(2)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if (gid.x<info.size.x && gid.y<info.size.y && gid.z<info.size.z) {\n"
" uint dstOffset=gid.x*info.extent.x+gid.y*info.extent.y+gid.z*info.extent.z+info.extent.w;\n"
" uint srcOffset=gid.x*info.stride.x+gid.y*info.stride.y+gid.z*info.stride.z+info.stride.w;\n"
" out[int(dstOffset)]=in[int(srcOffset)];\n"
" }\n"
"}\n"
// NHWC -> NCHW: each thread moves up to four consecutive channels of one pixel;
// the first is always copied, the rest only when below s.channel.
"template<typename IType,typename OType>\n"
"static inline void template_NHWC_to_NCHW(const device IType* in,\n"
" device OType* out,constant tensor_shape &s,uint2 gid) {\n"
" int b=gid.y/s.slice;\n"
" int c4=gid.y % s.slice;\n"
" \n"
" auto in_off=(b*s.size+gid.x)*s.channel+c4*4;\n"
" auto out_off=(b*s.channel+c4*4)*s.size+gid.x;\n"
" \n"
" out[out_off]=in[in_off];\n"
" if(c4*4+1<s.channel) {\n"
" out[out_off+s.size]=in[in_off+1];\n"
" }\n"
" if(c4*4+2<s.channel) {\n"
" out[out_off+s.size*2]=in[in_off+2];\n"
" }\n"
" if(c4*4+3<s.channel) {\n"
" out[out_off+s.size*3]=in[in_off+3];\n"
" }\n"
"}\n"
"kernel void upcast_f_NHWC_to_NCHW(const device M *in [[buffer(0)]],\n"
" device float *out [[buffer(1)]],\n"
" constant tensor_shape &s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.size && (int)gid.y<s.batch_slices) template_NHWC_to_NCHW<M,float>(in,out,s,gid);\n"
"}\n"
// NCHW -> NHWC: mirror image of the template above.
"template<typename IType,typename OType>\n"
"static inline void template_NCHW_to_NHWC(const device IType* in,\n"
" device OType* out,constant tensor_shape &s,uint2 gid) {\n"
" int b=gid.y/s.slice;\n"
" int c4=gid.y % s.slice;\n"
" \n"
" auto in_off=(b*s.channel+c4*4)*s.size+gid.x;\n"
" auto out_off=(b*s.size+gid.x)*s.channel+c4*4;\n"
" \n"
" out[out_off]=in[in_off];\n"
" if(c4*4+1<s.channel) {\n"
" out[out_off+1]=in[in_off+s.size];\n"
" }\n"
" if(c4*4+2<s.channel) {\n"
" out[out_off+2]=in[in_off+s.size*2];\n"
" }\n"
" if(c4*4+3<s.channel) {\n"
" out[out_off+3]=in[in_off+s.size*3];\n"
" }\n"
"}\n"
"kernel void downcast_f_NCHW_to_NHWC(const device float *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" constant tensor_shape &s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.size && (int)gid.y<s.batch_slices) template_NCHW_to_NHWC<float,M>(in,out,s,gid);\n"
"}\n"
;
// Metal Shading Language source for the softmax kernels, stored as one
// concatenated C string. The type names M and M4 are not defined inside this
// string; they are expected to be supplied by whatever is prepended when the
// shader is compiled.
//
// Every kernel handles one (gid.x, gid.y) = (inside index, outside index) pair
// and walks the softmax axis with a stride of s.inside_size elements.
const char* shader_MetalSoftmax_metal = 
// softmax_shape:
//   inside_size  - element stride between consecutive positions on the axis
//   axis_length  - number of positions on the axis
//   outside_size - number of independent rows (bound for gid.y)
//   flat_length  - number of valid scalar lanes along the axis; only
//                  softmax_on_reorder reads it, to mask padding lanes
"struct softmax_shape {\n"
" int inside_size;\n"
" int axis_length;\n"
" int outside_size;\n"
" int flat_length;\n"
"};\n"
// Horizontal max of the four lanes.
"static inline float softmax_max4(float4 V) {\n"
" return max(max(V[0],V[1]),max(V[2],V[3]));\n"
"}\n"
// Horizontal sum of the four lanes.
"static inline float softmax_sum4(float4 V) {\n"
" return V[0]+V[1]+V[2]+V[3];\n"
"}\n"
// Keeps lane k of vector number z only when z*4+k < limit; other lanes become 0.
// Zero is the neutral element for the sum and a safe value for padded outputs.
"static inline float4 softmax_filter(float4 V,int z,int limit) {\n"
" return select(0,V,z*4+int4(0,1,2,3)<limit);\n"
"}\n"
// Same lane mask as softmax_filter, but masked lanes become -FLT_MAX so they
// can never win a max reduction. Using 0 here made the reduction return 0 for
// rows whose valid values are all negative, which defeats the max-subtraction
// and lets every exp() underflow (sum == 0 -> NaN output).
"static inline float4 softmax_filter_max(float4 V,int z,int limit) {\n"
" return select(float4(-FLT_MAX),V,z*4+int4(0,1,2,3)<limit);\n"
"}\n"
// Scalar softmax: each axis position holds one M value.
"kernel void softmax_plane(const device M *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" constant softmax_shape& s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= s.inside_size || (int)gid.y >= s.outside_size) return;\n"
" \n"
" auto axis_off=gid.y*s.axis_length*s.inside_size+gid.x;\n"
" auto axis_in=in+axis_off;\n"
" auto axis_out=out+axis_off;\n"
" \n"
" // get max\n"
" auto max1=axis_in[0];\n"
" for (int i=1; i<s.axis_length; i++) {\n"
" max1=max(max1,axis_in[i*s.inside_size]);\n"
" }\n"
" \n"
" // get sum\n"
" float sum1=0;\n"
" for (int i=0; i<s.axis_length; i++) {\n"
// Same float expression as the numerator below, so the outputs sum to one.
" sum1 += exp(float(axis_in[i*s.inside_size]-max1));\n"
" }\n"
" \n"
" // output\n"
" for (int i=0; i<s.axis_length; i++) {\n"
" axis_out[i*s.inside_size]=M(exp(float(axis_in[i*s.inside_size]-max1))/sum1);\n"
" }\n"
"}\n"
// Softmax across the four lanes AND across the axis: all lanes of all M4
// vectors on the axis form one distribution of s.flat_length valid values.
// Lanes at or beyond flat_length are padding: they are excluded from the max
// and from the sum, and are written back as 0.
"kernel void softmax_on_reorder(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant softmax_shape& s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= s.inside_size || (int)gid.y >= s.outside_size) return;\n"
" \n"
" auto axis_off=gid.y*s.axis_length*s.inside_size+gid.x;\n"
" auto axis_in=in+axis_off;\n"
" auto axis_out=out+axis_off;\n"
" // get max\n"
" auto max4=softmax_filter_max(float4(axis_in[0]),0,s.flat_length);\n"
" for (int i=1; i<s.axis_length; i++) {\n"
" max4=max(max4,softmax_filter_max(float4(axis_in[i*s.inside_size]),i,s.flat_length));\n"
" }\n"
" float max1=softmax_max4(max4);\n"
" \n"
" // get sum\n"
" float4 sum4=0;\n"
" for (int i=0; i<s.axis_length; i++) {\n"
" sum4 += softmax_filter(exp(float4(axis_in[i*s.inside_size])-max1),i,s.flat_length);\n"
" }\n"
" float sum1=softmax_sum4(sum4);\n"
" \n"
" // output\n"
" for (int i=0; i<s.axis_length; i++) {\n"
// Padding lanes are forced to 0: with the true maximum subtracted, their
// unmasked exp() could otherwise overflow.
" axis_out[i*s.inside_size]=M4(softmax_filter(exp(float4(axis_in[i*s.inside_size])-max1)/sum1,i,s.flat_length));\n"
" }\n"
"}\n"
// Softmax along the axis only: each of the four lanes is an independent
// distribution, so max and sum are kept per lane and nothing is masked.
"kernel void softmax_off_reorder(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant softmax_shape& s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= s.inside_size || (int)gid.y >= s.outside_size) return;\n"
" auto axis_off=gid.y*s.axis_length*s.inside_size+gid.x;\n"
" auto axis_in=in+axis_off;\n"
" auto axis_out=out+axis_off;\n"
" // get max\n"
" auto max4=axis_in[0];\n"
" for (int i=1; i<s.axis_length; i++) {\n"
" max4=max(max4,axis_in[i*s.inside_size]);\n"
" }\n"
" // get sum\n"
" float4 sum4=0;\n"
" for (int i=0; i<s.axis_length; i++) {\n"
" sum4 += exp(float4(axis_in[i*s.inside_size]-max4));\n"
" }\n"
" // output\n"
" for (int i=0; i<s.axis_length; i++) {\n"
" axis_out[i*s.inside_size]=M4(exp(float4(axis_in[i*s.inside_size]-max4))/sum4);\n"
" }\n"
"}\n"
;
// Metal Shading Language source for the layer-normalisation kernels, stored as
// one concatenated C string. M and M4 are type names that are not defined in
// this string.
//
// Both kernels launch one thread per output element (gid.x) per row (gid.y).
// Every thread recomputes the mean and the squared-deviation sum of its whole
// row on its own before writing its single output element.
const char* shader_MetalLayerNorm_metal = 
// layernorm_constants:
//   inside         - number of scalars in one normalised row
//   outside        - number of rows (bound for gid.y)
//   eps            - added to the variance before the square root
//   has_gamma_beta - non-zero: output is norm*gamma+beta; zero: output is norm
"struct layernorm_constants {\n"
" int inside;\n"
" int outside;\n"
" float eps;\n"
" int has_gamma_beta;\n"
"};\n"
// Scalar variant: rows are cst.inside consecutive M values.
"kernel void layernorm_x1(const device M *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" constant layernorm_constants& cst [[buffer(2)]],\n"
" const device float *gamma [[buffer(3)]],\n"
" const device float *beta [[buffer(4)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.inside || (int)gid.y >= cst.outside) {\n"
" return;\n"
" }\n"
" auto in_data=in+gid.y*cst.inside;\n"
" auto out_data=out+gid.y*cst.inside;\n"
" float mean;\n"
" float sum=0.0f;\n"
" float square_sum=0.0f;\n"
" \n"
// First pass over the row: mean.
" for(int i=0; i<cst.inside; i++) {\n"
" sum += in_data[i];\n"
" }\n"
" mean=sum/cst.inside;\n"
" \n"
// Second pass over the row: sum of squared deviations from the mean.
" for(int i=0; i<cst.inside; i++) {\n"
" float dis=(in_data[i]-mean);\n"
" square_sum += dis*dis;\n"
" }\n"
// Despite its name, `var` holds 1/sqrt(variance+eps), i.e. the scale factor.
" float var=1.0/sqrt(square_sum/cst.inside+cst.eps);\n"
" \n"
" float norm=var*((float)in_data[gid.x]-mean);\n"
" if(cst.has_gamma_beta) {\n"
" out_data[gid.x]=(M)(norm*gamma[gid.x]+beta[gid.x]);\n"
" } else {\n"
" out_data[gid.x]=(M)(norm);\n"
" }\n"
"}\n"
// Vector variant: rows are read as cst.inside/4 consecutive M4 values; gamma
// and beta are read as float4. The statistics are still taken over all
// scalars (the divisor is cst.inside).
// NOTE(review): the row offset gid.y*cst.inside/4 and the loop bound
// cst.inside/4 only describe the same layout when cst.inside is a multiple of
// 4 - presumably the caller selects this kernel only in that case; confirm.
"kernel void layernorm_x4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant layernorm_constants& cst [[buffer(2)]],\n"
" const device float4 *gamma [[buffer(3)]],\n"
" const device float4 *beta [[buffer(4)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.inside/4 || (int)gid.y >= cst.outside) {\n"
" return;\n"
" }\n"
" auto in_data=in+gid.y*cst.inside/4;\n"
" auto out_data=out+gid.y*cst.inside/4;\n"
" float mean;\n"
" float sum=0.0f;\n"
" float square_sum=0.0f;\n"
" \n"
// First pass: accumulate all four lanes of every vector in the row.
" for(int i=0; i<cst.inside/4; i++) {\n"
" sum += in_data[i].x;\n"
" sum += in_data[i].y;\n"
" sum += in_data[i].z;\n"
" sum += in_data[i].w;\n"
" }\n"
" mean=sum/cst.inside;\n"
" \n"
// Second pass: squared deviation of every lane from the shared mean.
" for(int i=0; i<cst.inside/4; i++) {\n"
" float dis=(in_data[i].x-mean);\n"
" square_sum += dis*dis;\n"
" dis=(in_data[i].y-mean);\n"
" square_sum += dis*dis;\n"
" dis=(in_data[i].z-mean);\n"
" square_sum += dis*dis;\n"
" dis=(in_data[i].w-mean);\n"
" square_sum += dis*dis;\n"
" }\n"
// As in layernorm_x1, `var` is the reciprocal standard deviation.
" float var=1.0/sqrt(square_sum/cst.inside+cst.eps);\n"
" \n"
" float4 norm=var*((float4)in_data[gid.x]-mean);\n"
" if(cst.has_gamma_beta) {\n"
" out_data[gid.x]=(M4)(norm*gamma[gid.x]+beta[gid.x]);\n"
" } else {\n"
" out_data[gid.x]=(M4)(norm);\n"
" }\n"
"}\n"
;
// Metal Shading Language source for the Winograd convolution transform
// kernels, stored as one concatenated C string. M4, conv_activation_type,
// activate() and UP_DIV are used here but not defined in this string.
//
// Two source transforms (6x6 patch -> 36 values, 4x4 patch -> 16 values) and
// the two matching destination transforms (36 or 16 values -> 2x2 outputs)
// are provided. The multiplication between the two stages is not part of this
// string.
const char* shader_MetalConvolutionWinograd_metal = 
// winograd_constants, as used below:
//   input_shape.x/.y  - input plane width/height (bounds + plane stride)
//   input_shape.z     - multiplied by 4 to form the transformed row stride
//   output_shape.x/.y - output plane width/height (bounds + plane stride)
//   pad_x/pad_y       - subtracted from the tile origin in the input
//   unit_width/height - number of tiles in x / y (bounds for gid.x / gid.y)
//   unit              - step in pixels between neighbouring tiles
//   activation        - forwarded to activate() when writing outputs
"struct winograd_constants {\n"
" int4 input_shape;\n"
" int4 output_shape;\n"
" int pad_x;\n"
" int pad_y;\n"
" int unit_width;\n"
" int unit_height;\n"
" int unit;\n"
" conv_activation_type activation;\n"
"};\n"
// Reads input[x+y*width]; coordinates outside the plane yield 0.
"static inline M4 get_input(const device M4 *input,int x,int y,constant winograd_constants &cst) {\n"
" return x<cst.input_shape.x && y<cst.input_shape.y && x >= 0 && y >= 0 ? input[x+y*cst.input_shape.x] : 0;\n"
"}\n"
// Source transform for the 6x6 patch case. Thread (pos.x, pos.y) is the tile,
// pos.z selects the input plane. Naming: Sxy is the patch sample at column x,
// row y; mxy is column x after the transform has been applied along y.
"kernel void winograd_transform_source2_5_1(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant winograd_constants &cst [[buffer(2)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" auto pos=int3(gid);\n"
" if (pos.x<cst.unit_width && pos.y<cst.unit_height) {\n"
" int ix=pos.x*cst.unit-cst.pad_x;\n"
" int iy=pos.y*cst.unit-cst.pad_y;\n"
" auto z_in=in+pos.z*cst.input_shape.x*cst.input_shape.y;\n"
// Load the 6x6 patch (zero outside the input plane).
" auto S00=get_input(z_in,ix+0,iy+0,cst);\n"
" auto S10=get_input(z_in,ix+1,iy+0,cst);\n"
" auto S20=get_input(z_in,ix+2,iy+0,cst);\n"
" auto S30=get_input(z_in,ix+3,iy+0,cst);\n"
" auto S40=get_input(z_in,ix+4,iy+0,cst);\n"
" auto S50=get_input(z_in,ix+5,iy+0,cst);\n"
" auto S01=get_input(z_in,ix+0,iy+1,cst);\n"
" auto S11=get_input(z_in,ix+1,iy+1,cst);\n"
" auto S21=get_input(z_in,ix+2,iy+1,cst);\n"
" auto S31=get_input(z_in,ix+3,iy+1,cst);\n"
" auto S41=get_input(z_in,ix+4,iy+1,cst);\n"
" auto S51=get_input(z_in,ix+5,iy+1,cst);\n"
" auto S02=get_input(z_in,ix+0,iy+2,cst);\n"
" auto S12=get_input(z_in,ix+1,iy+2,cst);\n"
" auto S22=get_input(z_in,ix+2,iy+2,cst);\n"
" auto S32=get_input(z_in,ix+3,iy+2,cst);\n"
" auto S42=get_input(z_in,ix+4,iy+2,cst);\n"
" auto S52=get_input(z_in,ix+5,iy+2,cst);\n"
" auto S03=get_input(z_in,ix+0,iy+3,cst);\n"
" auto S13=get_input(z_in,ix+1,iy+3,cst);\n"
" auto S23=get_input(z_in,ix+2,iy+3,cst);\n"
" auto S33=get_input(z_in,ix+3,iy+3,cst);\n"
" auto S43=get_input(z_in,ix+4,iy+3,cst);\n"
" auto S53=get_input(z_in,ix+5,iy+3,cst);\n"
" auto S04=get_input(z_in,ix+0,iy+4,cst);\n"
" auto S14=get_input(z_in,ix+1,iy+4,cst);\n"
" auto S24=get_input(z_in,ix+2,iy+4,cst);\n"
" auto S34=get_input(z_in,ix+3,iy+4,cst);\n"
" auto S44=get_input(z_in,ix+4,iy+4,cst);\n"
" auto S54=get_input(z_in,ix+5,iy+4,cst);\n"
" auto S05=get_input(z_in,ix+0,iy+5,cst);\n"
" auto S15=get_input(z_in,ix+1,iy+5,cst);\n"
" auto S25=get_input(z_in,ix+2,iy+5,cst);\n"
" auto S35=get_input(z_in,ix+3,iy+5,cst);\n"
" auto S45=get_input(z_in,ix+4,iy+5,cst);\n"
" auto S55=get_input(z_in,ix+5,iy+5,cst);\n"
// First pass: combine the six rows of each column.
" auto m00=+S00-1.25*S02+0.25*S04;\n"
" auto m10=+S10-1.25*S12+0.25*S14;\n"
" auto m20=+S20-1.25*S22+0.25*S24;\n"
" auto m30=+S30-1.25*S32+0.25*S34;\n"
" auto m40=+S40-1.25*S42+0.25*S44;\n"
" auto m50=+S50-1.25*S52+0.25*S54;\n"
" auto m01=+0.666667*S01+0.666667*S02-0.166667*S03-0.166667*S04;\n"
" auto m11=+0.666667*S11+0.666667*S12-0.166667*S13-0.166667*S14;\n"
" auto m21=+0.666667*S21+0.666667*S22-0.166667*S23-0.166667*S24;\n"
" auto m31=+0.666667*S31+0.666667*S32-0.166667*S33-0.166667*S34;\n"
" auto m41=+0.666667*S41+0.666667*S42-0.166667*S43-0.166667*S44;\n"
" auto m51=+0.666667*S51+0.666667*S52-0.166667*S53-0.166667*S54;\n"
" auto m02=-0.666667*S01+0.666667*S02+0.166667*S03-0.166667*S04;\n"
" auto m12=-0.666667*S11+0.666667*S12+0.166667*S13-0.166667*S14;\n"
" auto m22=-0.666667*S21+0.666667*S22+0.166667*S23-0.166667*S24;\n"
" auto m32=-0.666667*S31+0.666667*S32+0.166667*S33-0.166667*S34;\n"
" auto m42=-0.666667*S41+0.666667*S42+0.166667*S43-0.166667*S44;\n"
" auto m52=-0.666667*S51+0.666667*S52+0.166667*S53-0.166667*S54;\n"
" auto m03=-0.0833333*S01-0.0416667*S02+0.0833333*S03+0.0416667*S04;\n"
" auto m13=-0.0833333*S11-0.0416667*S12+0.0833333*S13+0.0416667*S14;\n"
" auto m23=-0.0833333*S21-0.0416667*S22+0.0833333*S23+0.0416667*S24;\n"
" auto m33=-0.0833333*S31-0.0416667*S32+0.0833333*S33+0.0416667*S34;\n"
" auto m43=-0.0833333*S41-0.0416667*S42+0.0833333*S43+0.0416667*S44;\n"
" auto m53=-0.0833333*S51-0.0416667*S52+0.0833333*S53+0.0416667*S54;\n"
" auto m04=+0.0833333*S01-0.0416667*S02-0.0833333*S03+0.0416667*S04;\n"
" auto m14=+0.0833333*S11-0.0416667*S12-0.0833333*S13+0.0416667*S14;\n"
" auto m24=+0.0833333*S21-0.0416667*S22-0.0833333*S23+0.0416667*S24;\n"
" auto m34=+0.0833333*S31-0.0416667*S32-0.0833333*S33+0.0416667*S34;\n"
" auto m44=+0.0833333*S41-0.0416667*S42-0.0833333*S43+0.0416667*S44;\n"
" auto m54=+0.0833333*S51-0.0416667*S52-0.0833333*S53+0.0416667*S54;\n"
" auto m05=+4.0*S01-5.0*S03+S05;\n"
" auto m15=+4.0*S11-5.0*S13+S15;\n"
" auto m25=+4.0*S21-5.0*S23+S25;\n"
" auto m35=+4.0*S31-5.0*S33+S35;\n"
" auto m45=+4.0*S41-5.0*S43+S45;\n"
" auto m55=+4.0*S51-5.0*S53+S55;\n"
// Destination addressing: the linear tile index is split into a group of four
// (dst_y) and a position inside the group (dst_y_origin % 4); each of the 36
// transformed values lands `stride` elements after the previous one.
" int dst_x_origin=pos.z;\n"
" int dst_y_origin=cst.unit_width*pos.y+pos.x;\n"
" int dst_y_stride=cst.input_shape.z*4;\n"
" int dst_y=dst_y_origin/4;\n"
" int dst_x=dst_y_origin % 4+4*dst_x_origin;\n"
" int src_height=UP_DIV(cst.unit_width*cst.unit_height,4);\n"
" int stride=src_height*dst_y_stride;\n"
" auto xy_out=out+dst_y*dst_y_stride+dst_x;\n"
// Second pass: the same six coefficient rows applied across the columns,
// written out one value at a time.
" *xy_out=+m00-1.25*m20+0.25*m40;\n"
" xy_out += stride; *xy_out=+0.666667*m10+0.666667*m20-0.166667*m30-0.166667*m40;\n"
" xy_out += stride; *xy_out=-0.666667*m10+0.666667*m20+0.166667*m30-0.166667*m40;\n"
" xy_out += stride; *xy_out=-0.0833333*m10-0.0416667*m20+0.0833333*m30+0.0416667*m40;\n"
" xy_out += stride; *xy_out=+0.0833333*m10-0.0416667*m20-0.0833333*m30+0.0416667*m40;\n"
" xy_out += stride; *xy_out=+4.0*m10-5.0*m30+m50;\n"
" xy_out += stride; *xy_out=+m01-1.25*m21+0.25*m41;\n"
" xy_out += stride; *xy_out=+0.666667*m11+0.666667*m21-0.166667*m31-0.166667*m41;\n"
" xy_out += stride; *xy_out=-0.666667*m11+0.666667*m21+0.166667*m31-0.166667*m41;\n"
" xy_out += stride; *xy_out=-0.0833333*m11-0.0416667*m21+0.0833333*m31+0.0416667*m41;\n"
" xy_out += stride; *xy_out=+0.0833333*m11-0.0416667*m21-0.0833333*m31+0.0416667*m41;\n"
" xy_out += stride; *xy_out=+4.0*m11-5.0*m31+m51;\n"
" xy_out += stride; *xy_out=+m02-1.25*m22+0.25*m42;\n"
" xy_out += stride; *xy_out=+0.666667*m12+0.666667*m22-0.166667*m32-0.166667*m42;\n"
" xy_out += stride; *xy_out=-0.666667*m12+0.666667*m22+0.166667*m32-0.166667*m42;\n"
" xy_out += stride; *xy_out=-0.0833333*m12-0.0416667*m22+0.0833333*m32+0.0416667*m42;\n"
" xy_out += stride; *xy_out=+0.0833333*m12-0.0416667*m22-0.0833333*m32+0.0416667*m42;\n"
" xy_out += stride; *xy_out=+4.0*m12-5.0*m32+m52;\n"
" xy_out += stride; *xy_out=+m03-1.25*m23+0.25*m43;\n"
" xy_out += stride; *xy_out=+0.666667*m13+0.666667*m23-0.166667*m33-0.166667*m43;\n"
" xy_out += stride; *xy_out=-0.666667*m13+0.666667*m23+0.166667*m33-0.166667*m43;\n"
" xy_out += stride; *xy_out=-0.0833333*m13-0.0416667*m23+0.0833333*m33+0.0416667*m43;\n"
" xy_out += stride; *xy_out=+0.0833333*m13-0.0416667*m23-0.0833333*m33+0.0416667*m43;\n"
" xy_out += stride; *xy_out=+4.0*m13-5.0*m33+m53;\n"
" xy_out += stride; *xy_out=+m04-1.25*m24+0.25*m44;\n"
" xy_out += stride; *xy_out=+0.666667*m14+0.666667*m24-0.166667*m34-0.166667*m44;\n"
" xy_out += stride; *xy_out=-0.666667*m14+0.666667*m24+0.166667*m34-0.166667*m44;\n"
" xy_out += stride; *xy_out=-0.0833333*m14-0.0416667*m24+0.0833333*m34+0.0416667*m44;\n"
" xy_out += stride; *xy_out=+0.0833333*m14-0.0416667*m24-0.0833333*m34+0.0416667*m44;\n"
" xy_out += stride; *xy_out=+4.0*m14-5.0*m34+m54;\n"
" xy_out += stride; *xy_out=+m05-1.25*m25+0.25*m45;\n"
" xy_out += stride; *xy_out=+0.666667*m15+0.666667*m25-0.166667*m35-0.166667*m45;\n"
" xy_out += stride; *xy_out=-0.666667*m15+0.666667*m25+0.166667*m35-0.166667*m45;\n"
" xy_out += stride; *xy_out=-0.0833333*m15-0.0416667*m25+0.0833333*m35+0.0416667*m45;\n"
" xy_out += stride; *xy_out=+0.0833333*m15-0.0416667*m25-0.0833333*m35+0.0416667*m45;\n"
" xy_out += stride; *xy_out=+4.0*m15-5.0*m35+m55;\n"
" }\n"
"}\n"
// Source transform for the 4x4 patch case: same structure as above with a 4x4
// patch, a four-row coefficient set and 16 output values.
"kernel void winograd_transform_source2_3_1(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant winograd_constants &cst [[buffer(2)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" auto pos=int3(gid);\n"
" if (pos.x<cst.unit_width && pos.y<cst.unit_height) {\n"
" int ix=pos.x*cst.unit-cst.pad_x;\n"
" int iy=pos.y*cst.unit-cst.pad_y;\n"
" auto z_in=in+pos.z*cst.input_shape.x*cst.input_shape.y;\n"
// Load the 4x4 patch (zero outside the input plane).
" auto S00=get_input(z_in,ix+0,iy+0,cst);\n"
" auto S10=get_input(z_in,ix+1,iy+0,cst);\n"
" auto S20=get_input(z_in,ix+2,iy+0,cst);\n"
" auto S30=get_input(z_in,ix+3,iy+0,cst);\n"
" auto S01=get_input(z_in,ix+0,iy+1,cst);\n"
" auto S11=get_input(z_in,ix+1,iy+1,cst);\n"
" auto S21=get_input(z_in,ix+2,iy+1,cst);\n"
" auto S31=get_input(z_in,ix+3,iy+1,cst);\n"
" auto S02=get_input(z_in,ix+0,iy+2,cst);\n"
" auto S12=get_input(z_in,ix+1,iy+2,cst);\n"
" auto S22=get_input(z_in,ix+2,iy+2,cst);\n"
" auto S32=get_input(z_in,ix+3,iy+2,cst);\n"
" auto S03=get_input(z_in,ix+0,iy+3,cst);\n"
" auto S13=get_input(z_in,ix+1,iy+3,cst);\n"
" auto S23=get_input(z_in,ix+2,iy+3,cst);\n"
" auto S33=get_input(z_in,ix+3,iy+3,cst);\n"
// First pass: combine the four rows of each column.
" auto m00=+S00-S02;\n"
" auto m10=+S10-S12;\n"
" auto m20=+S20-S22;\n"
" auto m30=+S30-S32;\n"
" auto m01=+0.5*S01+0.5*S02;\n"
" auto m11=+0.5*S11+0.5*S12;\n"
" auto m21=+0.5*S21+0.5*S22;\n"
" auto m31=+0.5*S31+0.5*S32;\n"
" auto m02=-0.5*S01+0.5*S02;\n"
" auto m12=-0.5*S11+0.5*S12;\n"
" auto m22=-0.5*S21+0.5*S22;\n"
" auto m32=-0.5*S31+0.5*S32;\n"
" auto m03=-S01+S03;\n"
" auto m13=-S11+S13;\n"
" auto m23=-S21+S23;\n"
" auto m33=-S31+S33;\n"
// Destination addressing is identical to winograd_transform_source2_5_1.
" int dst_x_origin=pos.z;\n"
" int dst_y_origin=cst.unit_width*pos.y+pos.x;\n"
" int dst_y_stride=cst.input_shape.z*4;\n"
" int dst_y=dst_y_origin/4;\n"
" int dst_x=dst_y_origin % 4+4*dst_x_origin;\n"
" int src_height=UP_DIV(cst.unit_width*cst.unit_height,4);\n"
" int stride=src_height*dst_y_stride;\n"
" auto xy_out=out+dst_y*dst_y_stride+dst_x;\n"
// Second pass across the columns; 16 values, `stride` elements apart.
" *xy_out=+m00-m20;\n"
" xy_out += stride; *xy_out=+0.5*m10+0.5*m20;\n"
" xy_out += stride; *xy_out=-0.5*m10+0.5*m20;\n"
" xy_out += stride; *xy_out=-m10+m30;\n"
" xy_out += stride; *xy_out=+m01-m21;\n"
" xy_out += stride; *xy_out=+0.5*m11+0.5*m21;\n"
" xy_out += stride; *xy_out=-0.5*m11+0.5*m21;\n"
" xy_out += stride; *xy_out=-m11+m31;\n"
" xy_out += stride; *xy_out=+m02-m22;\n"
" xy_out += stride; *xy_out= +0.5*m12+0.5*m22;\n"
" xy_out += stride; *xy_out=-0.5*m12+0.5*m22;\n"
" xy_out += stride; *xy_out=-m12+m32;\n"
" xy_out += stride; *xy_out=+m03-m23;\n"
" xy_out += stride; *xy_out=+0.5*m13+0.5*m23;\n"
" xy_out += stride; *xy_out=-0.5*m13+0.5*m23;\n"
" xy_out += stride; *xy_out=-m13+m33;\n"
" }\n"
"}\n"
// Stores activate(V, cst.activation) at (x, y) of an output plane. The caller
// is responsible for the bounds check.
"static inline void set_output(constant winograd_constants &cst,device M4 *output,int x,int y,M4 V) {\n"
" output[y*cst.output_shape.x+x]=activate(V,cst.activation);\n"
"}\n"
// Destination transform for the 36-value case. Reads 36 values spaced dst_w
// elements apart, reduces them to a 2x2 output tile, adds the per-plane bias
// biasTerms[pos.z] and writes the pixels that fall inside the output plane.
"kernel void winograd_transform_dest2_5_1(const device M4 *in [[buffer(0)]],\n"
" const device M4 *biasTerms [[buffer(1)]],\n"
" device M4 *out [[buffer(2)]],\n"
" constant winograd_constants &cst [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" auto pos=int3(gid);\n"
" if (pos.x<cst.unit_width && pos.y<cst.unit_height) {\n"
" int dst_w=UP_DIV(cst.unit_width*cst.unit_height,4);\n"
" int dst_x_origin=cst.unit_width*pos.y+pos.x;\n"
" int dst_x=dst_x_origin/4;\n"
" int dst_y=4*pos.z+dst_x_origin % 4;\n"
// 36 = number of transformed values per tile in this variant.
" int dst_y_stride=dst_w*36;\n"
" auto xy_in=in+dst_y*dst_y_stride+dst_x;\n"
" auto S00=*xy_in; xy_in += dst_w;\n"
" auto S10=*xy_in; xy_in += dst_w;\n"
" auto S20=*xy_in; xy_in += dst_w;\n"
" auto S30=*xy_in; xy_in += dst_w;\n"
" auto S40=*xy_in; xy_in += dst_w;\n"
" auto S50=*xy_in; xy_in += dst_w;\n"
" auto S01=*xy_in; xy_in += dst_w;\n"
" auto S11=*xy_in; xy_in += dst_w;\n"
" auto S21=*xy_in; xy_in += dst_w;\n"
" auto S31=*xy_in; xy_in += dst_w;\n"
" auto S41=*xy_in; xy_in += dst_w;\n"
" auto S51=*xy_in; xy_in += dst_w;\n"
" auto S02=*xy_in; xy_in += dst_w;\n"
" auto S12=*xy_in; xy_in += dst_w;\n"
" auto S22=*xy_in; xy_in += dst_w;\n"
" auto S32=*xy_in; xy_in += dst_w;\n"
" auto S42=*xy_in; xy_in += dst_w;\n"
" auto S52=*xy_in; xy_in += dst_w;\n"
" auto S03=*xy_in; xy_in += dst_w;\n"
" auto S13=*xy_in; xy_in += dst_w;\n"
" auto S23=*xy_in; xy_in += dst_w;\n"
" auto S33=*xy_in; xy_in += dst_w;\n"
" auto S43=*xy_in; xy_in += dst_w;\n"
" auto S53=*xy_in; xy_in += dst_w;\n"
" auto S04=*xy_in; xy_in += dst_w;\n"
" auto S14=*xy_in; xy_in += dst_w;\n"
" auto S24=*xy_in; xy_in += dst_w;\n"
" auto S34=*xy_in; xy_in += dst_w;\n"
" auto S44=*xy_in; xy_in += dst_w;\n"
" auto S54=*xy_in; xy_in += dst_w;\n"
" auto S05=*xy_in; xy_in += dst_w;\n"
" auto S15=*xy_in; xy_in += dst_w;\n"
" auto S25=*xy_in; xy_in += dst_w;\n"
" auto S35=*xy_in; xy_in += dst_w;\n"
" auto S45=*xy_in; xy_in += dst_w;\n"
" auto S55=*xy_in;\n"
// First pass: reduce the six rows of each column to the two output rows.
" auto m00=+S00+S01+S02+S03+S04;\n"
" auto m10=+S10+S11+S12+S13+S14;\n"
" auto m20=+S20+S21+S22+S23+S24;\n"
" auto m30=+S30+S31+S32+S33+S34;\n"
" auto m40=+S40+S41+S42+S43+S44;\n"
" auto m50=+S50+S51+S52+S53+S54;\n"
" auto m01=+S01-S02+2.0*S03-2.0*S04+S05;\n"
" auto m11=+S11-S12+2.0*S13-2.0*S14+S15;\n"
" auto m21=+S21-S22+2.0*S23-2.0*S24+S25;\n"
" auto m31=+S31-S32+2.0*S33-2.0*S34+S35;\n"
" auto m41=+S41-S42+2.0*S43-2.0*S44+S45;\n"
" auto m51=+S51-S52+2.0*S53-2.0*S54+S55;\n"
" // write output\n"
" auto b4=biasTerms[int(pos.z)];\n"
" int oy=pos.y*cst.unit;\n"
" int ox=pos.x*cst.unit;\n"
" auto z_out=out+pos.z*cst.output_shape.x*cst.output_shape.y;\n"
" \n"
// Second pass across the columns happens inline in each write. The (ox, oy)
// pixel is written unconditionally; its neighbours only when inside the plane.
" /* if true */ {\n"
" set_output(cst,z_out,ox+0,oy+0,b4+m00+m10+m20+m30+m40);\n"
" }\n"
" if (ox+1<cst.output_shape.x) {\n"
" set_output(cst,z_out,ox+1,oy+0,b4+m10-m20+2.0*m30-2.0*m40+m50);\n"
" }\n"
" if (oy+1<cst.output_shape.y) {\n"
" set_output(cst,z_out,ox+0,oy+1,b4+m01+m11+m21+m31+m41);\n"
" }\n"
" if (ox+1<cst.output_shape.x && oy+1<cst.output_shape.y) {\n"
" set_output(cst,z_out,ox+1,oy+1,b4+m11-m21+2.0*m31-2.0*m41+m51);\n"
" }\n"
" }\n"
"}\n"
// Destination transform for the 16-value case: same structure as
// winograd_transform_dest2_5_1 with 16 inputs per tile.
"kernel void winograd_transform_dest2_3_1(const device M4 *in [[buffer(0)]],\n"
" const device M4 *biasTerms [[buffer(1)]],\n"
" device M4 *out [[buffer(2)]],\n"
" constant winograd_constants &cst [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" auto pos=int3(gid);\n"
" if (pos.x<cst.unit_width && pos.y<cst.unit_height) {\n"
" int dst_w=UP_DIV(cst.unit_width*cst.unit_height,4);\n"
" int dst_x_origin=cst.unit_width*pos.y+pos.x;\n"
" int dst_x=dst_x_origin/4;\n"
" int dst_y=4*pos.z+dst_x_origin % 4;\n"
// 16 = number of transformed values per tile in this variant.
" int dst_y_stride=dst_w*16;\n"
" auto xy_in=in+dst_y*dst_y_stride+dst_x;\n"
" auto S00=*xy_in; xy_in += dst_w;\n"
" auto S10=*xy_in; xy_in += dst_w;\n"
" auto S20=*xy_in; xy_in += dst_w;\n"
" auto S30=*xy_in; xy_in += dst_w;\n"
" auto S01=*xy_in; xy_in += dst_w;\n"
" auto S11=*xy_in; xy_in += dst_w;\n"
" auto S21=*xy_in; xy_in += dst_w;\n"
" auto S31=*xy_in; xy_in += dst_w;\n"
" auto S02=*xy_in; xy_in += dst_w;\n"
" auto S12=*xy_in; xy_in += dst_w;\n"
" auto S22=*xy_in; xy_in += dst_w;\n"
" auto S32=*xy_in; xy_in += dst_w;\n"
" auto S03=*xy_in; xy_in += dst_w;\n"
" auto S13=*xy_in; xy_in += dst_w;\n"
" auto S23=*xy_in; xy_in += dst_w;\n"
" auto S33=*xy_in;\n"
// First pass: reduce the four rows of each column to the two output rows.
" auto m00=+S00+S01+S02;\n"
" auto m10=+S10+S11+S12;\n"
" auto m20=+S20+S21+S22;\n"
" auto m30=+S30+S31+S32;\n"
" auto m01=+S01-S02+S03;\n"
" auto m11=+S11-S12+S13;\n"
" auto m21=+S21-S22+S23;\n"
" auto m31=+S31-S32+S33;\n"
" // write output\n"
" auto b4=biasTerms[int(pos.z)];\n"
" int oy=pos.y*cst.unit;\n"
" int ox=pos.x*cst.unit;\n"
" auto z_out=out+pos.z*cst.output_shape.x*cst.output_shape.y;\n"
" \n"
// Bounds handling matches winograd_transform_dest2_5_1.
" /* if true */ {\n"
" set_output(cst,z_out,ox+0,oy+0,b4+m00+m10+m20);\n"
" }\n"
" if (ox+1<cst.output_shape.x) {\n"
" set_output(cst,z_out,ox+1,oy+0,b4+m10-m20+m30);\n"
" }\n"
" if (oy+1<cst.output_shape.y) {\n"
" set_output(cst,z_out,ox+0,oy+1,b4+m01+m11+m21);\n"
" }\n"
" if (ox+1<cst.output_shape.x && oy+1<cst.output_shape.y) {\n"
" set_output(cst,z_out,ox+1,oy+1,b4+m11-m21+m31);\n"
" }\n"
" }\n"
"}\n"
;
// Metal Shading Language source for the matrix-multiply kernels, stored as one
// concatenated C string. M and FLOAT are type names that are not defined in
// this string; the accumulator uses FLOAT and the stored result uses M.
//
// One thread computes one output element: gid.x is the output column and
// gid.y the output row.
const char* shader_MetalMatMul_metal = 
// matmul_shape, as used below:
//   mat_size.x  - output columns (bound for gid.x, and the output row stride)
//   mat_size.y  - output rows (bound for gid.y)
//   mat_size.z  - length of the reduction loop
//   in_stride.x - in0 step per output row       in_stride.y - in0 step per reduction step
//   in_stride.z - in1 step per output column    in_stride.w - in1 step per reduction step
// Arbitrary strides let the same kernel serve transposed operands.
"struct matmul_shape {\n"
" int4 mat_size;\n"
" int4 in_stride;\n"
"};\n"
// out[gid.y][gid.x] = sum over i of in0[...]*in1[...].
"kernel void matmul(const device M *in0 [[buffer(0)]],\n"
" const device M *in1 [[buffer(1)]],\n"
" device M *out [[buffer(2)]],\n"
" constant matmul_shape &s [[buffer(3)]],\n"
" uint2 gid[[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.mat_size.x && (int)gid.y<s.mat_size.y) {\n"
" auto off_in0=in0+int(gid.y)*s.in_stride.x;\n"
" auto off_in1=in1+int(gid.x)*s.in_stride.z;\n"
" FLOAT V=0.f;\n"
// Both operand pointers advance by their own reduction stride each step.
" for (int i=0; i<s.mat_size.z; i++,off_in0 += s.in_stride.y,off_in1 += s.in_stride.w) {\n"
" V += FLOAT(*off_in0)*FLOAT(*off_in1);\n"
" }\n"
" out[int(gid.y)*s.mat_size.x+int(gid.x)]=M(V);\n"
" }\n"
"}\n"
// Same product as matmul, plus biasValue[gid.x] (one bias per output column).
// Note the different buffer indices: bias is buffer 2, out 3, shape 4.
"kernel void matmul_bias(const device M *in0 [[buffer(0)]],\n"
" const device M *in1 [[buffer(1)]],\n"
" const device M *biasValue [[buffer(2)]],\n"
" device M *out [[buffer(3)]],\n"
" constant matmul_shape &s [[buffer(4)]],\n"
" uint2 gid[[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.mat_size.x && (int)gid.y<s.mat_size.y) {\n"
" auto off_in0=in0+int(gid.y)*s.in_stride.x;\n"
" auto off_in1=in1+int(gid.x)*s.in_stride.z;\n"
" FLOAT V=0.f;\n"
" for (int i=0; i<s.mat_size.z; i++,off_in0 += s.in_stride.y,off_in1 += s.in_stride.w) {\n"
" V += FLOAT(*off_in0)*FLOAT(*off_in1);\n"
" }\n"
// The accumulator is converted to M first; the bias is then added in M.
" out[int(gid.y)*s.mat_size.x+int(gid.x)]=M(V)+biasValue[(int)(gid.x)];\n"
" }\n"
"}\n"
;
// Metal Shading Language source for the scale kernel, stored as one
// concatenated C string. M4 is a type name that is not defined in this string.
const char* shader_MetalScale_metal = 
// scale_shape, as used below:
//   size  - number of M4 elements per gid.y row (bound for gid.x)
//   steps - number of scale/bias entries; rows cycle through them
//   batch - steps*batch is the bound for gid.y
"struct scale_shape {\n"
" int size;\n"
" int steps;\n"
" int batch;\n"
"};\n"
// out = in*scales[z]+biasTerms[z] with z = gid.y % steps, so every row of a
// batch reuses the same `steps` scale/bias vectors. scales and biasTerms are
// stored as float4 and converted to M4 before the arithmetic.
"kernel void scale_ca(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant scale_shape &s [[buffer(2)]],\n"
" const device float4 *scales [[buffer(3)]],\n"
" const device float4 *biasTerms [[buffer(4)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= s.size || (int)gid.y >= s.steps*s.batch) return;\n"
" int z=gid.y % s.steps;\n"
" out[int(gid.y)*s.size+int(gid.x)] =\n"
" in [int(gid.y)*s.size+int(gid.x)]*M4(scales[z])+M4(biasTerms[z]);\n"
"}\n"
;
// Metal Shading Language source for the deconvolution (transposed convolution)
// kernels, stored as one concatenated C string. M4, M4x4, FLOAT4,
// conv_activation_type, activate(), UP_DIV and ROUND_UP are used here but not
// defined in this string.
//
// One thread produces one output pixel: gid.x/gid.y are the output column and
// row, gid.z enumerates batch*output_slice planes.
const char* shader_MetalDeconvolution_metal = 
// deconv_constants: geometry of the operation. *_size is used as the element
// count of one plane; *_slice is the number of planes per batch item;
// delta_* are the loop steps for the kernel taps and the matching input
// positions; has_bias is only read by the non-depthwise kernel.
"struct deconv_constants {\n"
" int input_width;\n"
" int input_height;\n"
" int input_size;\n"
" int input_slice;\n"
" int output_width;\n"
" int output_height;\n"
" int output_size;\n"
" int output_slice;\n"
" \n"
" int kernel_x;\n"
" int kernel_y;\n"
" int kernel_size;\n"
" int stride_x;\n"
" int stride_y;\n"
" int pad_x;\n"
" int pad_y;\n"
" int dilation_x;\n"
" int dilation_y;\n"
" \n"
" int delta_ky;\n"
" int delta_kx;\n"
" int delta_iy;\n"
" int delta_ix;\n"
" int has_bias;\n"
" int batch;\n"
" conv_activation_type activation;\n"
"};\n"
// Full deconvolution: every output plane o accumulates contributions from all
// input_slice input planes through 4x4 weight matrices.
"kernel void deconv(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant deconv_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
" \n"
// gid.z = b*output_slice+o: batch item and output plane within it.
" int b=gid.z/cst.output_slice;\n"
" int o=gid.z % cst.output_slice;\n"
" FLOAT4 result=cst.has_bias ? FLOAT4(biasTerms[o]) : 0;\n"
" int oy=(int)gid.y+cst.pad_y;\n"
" int ox=(int)gid.x+cst.pad_x;\n"
" int max_sy=min((cst.input_height-1)*cst.stride_y,oy/cst.stride_y*cst.stride_y);\n"
" int max_sx=min((cst.input_width-1)*cst.stride_x,ox/cst.stride_x*cst.stride_x);\n"
" int min_ky=UP_DIV(oy-max_sy,cst.dilation_y);\n"
" int min_kx=UP_DIV(ox-max_sx,cst.dilation_x);\n"
" \n"
// Only output pixels that line up with the stride grid receive contributions;
// all others keep just the bias.
" if ((oy-min_ky*cst.dilation_y) % cst.stride_y == 0 && (ox-min_kx*cst.dilation_x) % cst.stride_x == 0) {\n"
" int min_sy=max(0,ROUND_UP(oy+cst.dilation_y-cst.kernel_y*cst.dilation_y,cst.stride_y));\n"
" int min_sx=max(0,ROUND_UP(ox+cst.dilation_x-cst.kernel_x*cst.dilation_x,cst.stride_x));\n"
" int max_ky=(oy-min_sy)/cst.dilation_y;\n"
" int max_kx=(ox-min_sx)/cst.dilation_x;\n"
" int min_iy=(oy-max_ky*cst.dilation_y)/cst.stride_y;\n"
" int min_ix=(ox-max_kx*cst.dilation_x)/cst.stride_x;\n"
" \n"
// Weights depend on the output plane only; the input depends on the batch item.
" auto o_wt=wt+o*cst.input_slice*cst.kernel_size;\n"
" auto b_in=in+b*cst.input_slice*cst.input_size;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
// Kernel taps run downwards while the matching input position runs upwards.
" for (auto ky=max_ky,iy=min_iy; ky >= min_ky; ky -= cst.delta_ky,iy += cst.delta_iy) {\n"
" for (auto kx=max_kx,ix=min_ix; kx >= min_kx; kx -= cst.delta_kx,ix += cst.delta_ix) {\n"
" auto wt4=o_wt[z*cst.kernel_size+ky*cst.kernel_x+kx];\n"
" auto in4=b_in[z*cst.input_size+iy*cst.input_width+ix];\n"
" result += FLOAT4(in4*wt4);\n"
" }\n"
" }\n"
" }\n"
" }\n"
" out[(int)gid.z*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x]=activate(M4(result),cst.activation);\n"
"}\n"
// Depthwise deconvolution: plane gid.z of the input maps to plane gid.z of the
// output, with one M4 weight per kernel tap and per channel group.
// The bias is always read here (has_bias is not consulted).
"kernel void deconv_depthwise(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant deconv_constants& cst [[buffer(2)]],\n"
" const device M4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
" \n"
" FLOAT4 result=FLOAT4(biasTerms[(int)(gid.z % cst.input_slice)]);\n"
" \n"
" int oy=(int)gid.y+cst.pad_y;\n"
" int ox=(int)gid.x+cst.pad_x;\n"
" int max_sy=min((cst.input_height-1)*cst.stride_y,oy/cst.stride_y*cst.stride_y);\n"
" int max_sx=min((cst.input_width-1)*cst.stride_x,ox/cst.stride_x*cst.stride_x);\n"
" int min_ky=UP_DIV(oy-max_sy,cst.dilation_y);\n"
" int min_kx=UP_DIV(ox-max_sx,cst.dilation_x);\n"
" \n"
" if ((oy-min_ky*cst.dilation_y) % cst.stride_y == 0 && (ox-min_kx*cst.dilation_x) % cst.stride_x == 0) {\n"
" int min_sy=max(0,ROUND_UP(oy+cst.dilation_y-cst.kernel_y*cst.dilation_y,cst.stride_y));\n"
" int min_sx=max(0,ROUND_UP(ox+cst.dilation_x-cst.kernel_x*cst.dilation_x,cst.stride_x));\n"
" int max_ky=(oy-min_sy)/cst.dilation_y;\n"
" int max_kx=(ox-min_sx)/cst.dilation_x;\n"
" int min_iy=(oy-max_ky*cst.dilation_y)/cst.stride_y;\n"
" int min_ix=(ox-max_kx*cst.dilation_x)/cst.stride_x;\n"
" \n"
// Weights are per channel group, not per batch item: index them with the same
// gid.z % input_slice that selects the bias. Using the raw gid.z read past the
// per-channel weights for every batch item after the first.
" auto z_wt=wt+(int)(gid.z % cst.input_slice)*cst.kernel_size;\n"
" auto z_in=in+(int)gid.z*cst.input_size;\n"
" for (auto ky=max_ky,iy=min_iy; ky >= min_ky; ky -= cst.delta_ky,iy += cst.delta_iy) {\n"
" for (auto kx=max_kx,ix=min_ix; kx >= min_kx; kx -= cst.delta_kx,ix += cst.delta_ix) {\n"
" auto wt4=z_wt[ky*cst.kernel_x+kx];\n"
" auto in4=z_in[iy*cst.input_width+ix];\n"
" result += FLOAT4(in4*wt4);\n"
" }\n"
" }\n"
" }\n"
" out[(int)gid.z*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x]=activate(M4(result),cst.activation);\n"
"}\n"
;
// Metal Shading Language source for the max/average pooling kernels, stored as
// one concatenated C string. M4 and FLOAT4 are type names that are not defined
// in this string.
//
// One thread produces one output pixel of one plane: gid.x/gid.y are the
// output column and row, gid.z the plane index (bounded by s.slice).
const char* shader_MetalPooling_metal = 
// pooling_sizes: input/output plane extents, plane count, and the window
// size, stride and padding in each direction.
"struct pooling_sizes {\n"
" int input_width;\n"
" int input_height;\n"
" int output_width;\n"
" int output_height;\n"
" int slice;\n"
" int kernel_width;\n"
" int kernel_height;\n"
" int stride_width;\n"
" int stride_height;\n"
" int pad_width;\n"
" int pad_height;\n"
"};\n"
// Max pooling. Window coordinates that fall outside the input are clamped to
// the nearest valid row/column, so those positions re-read edge pixels instead
// of contributing a padding value.
"kernel void pooling_max(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant pooling_sizes& s [[buffer(2)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if (any(gid >= uint3(s.output_width,s.output_height,s.slice))) return;\n"
" \n"
// Top-left corner of the window in input coordinates (may be negative).
" int off_x=gid.x*s.stride_width-s.pad_width;\n"
" int off_y=gid.y*s.stride_height-s.pad_height;\n"
" int x_max=s.input_width-1;\n"
" int y_max=s.input_height-1;\n"
" int ex=off_x+s.kernel_width;\n"
" int ey=off_y+s.kernel_height;\n"
" \n"
" auto z_in=in+(int)gid.z*s.input_width*s.input_height;\n"
// Seed the running maximum with the (clamped) first window element.
" auto result=M4(z_in[clamp(off_y,0,y_max)*s.input_width+clamp(off_x,0,x_max)]);\n"
" for (int y=off_y; y<ey; y++) {\n"
" auto y_in=z_in+clamp(y,0,y_max)*s.input_width;\n"
" for (int x=off_x; x<ex; x++) {\n"
" result=max(result,y_in[clamp(x,0,x_max)]);\n"
" }\n"
" }\n"
" out[(int)gid.z*s.output_width*s.output_height+(int)gid.y*s.output_width+(int)gid.x]=result;\n"
"}\n"
// Average pooling. The window is clipped to the input and the sum is divided
// by the number of pixels in the clipped window, so padding does not count
// towards the average.
"kernel void pooling_avg(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant pooling_sizes& s [[buffer(2)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if (any(gid >= uint3(s.output_width,s.output_height,s.slice))) return;\n"
" \n"
" int off_x=gid.x*s.stride_width-s.pad_width;\n"
" int off_y=gid.y*s.stride_height-s.pad_height;\n"
// sx/sy equal max(off, 0); ex/ey equal min(off+kernel, input extent).
" int sx=off_x+max(0,-off_x);\n"
" int sy=off_y+max(0,-off_y);\n"
" int ex=off_x+min(s.kernel_width,s.input_width-off_x);\n"
" int ey=off_y+min(s.kernel_height,s.input_height-off_y);\n"
" \n"
" FLOAT4 result=0;\n"
" auto z_in=in+(int)gid.z*s.input_width*s.input_height;\n"
" for (int y=sy; y<ey; y++) {\n"
" for (int x=sx; x<ex; x++) {\n"
" result += FLOAT4(z_in[y*s.input_width+x]);\n"
" }\n"
" }\n"
// A non-positive count leaves the accumulated sum unscaled (multiplies by 1).
" int count=(ey-sy)*(ex-sx);\n"
" FLOAT4 div=count>0 ? 1.f/count : 1;\n"
" out[(int)gid.z*s.output_width*s.output_height+(int)gid.y*s.output_width+(int)gid.x]=M4(result*div);\n"
"}\n"
;
// Metal Shading Language source for the ROI max-pooling kernel, stored as one
// concatenated C string. M and M4 are type names that are not defined in this
// string.
const char* shader_MetalROIPooling_metal = 
// ROI_shape: input/output plane extents (*_size is used as the element count
// of one plane), the number of planes per batch item (slices) and the factor
// applied to ROI coordinates before rounding (spatial_scale).
"struct ROI_shape {\n"
" int input_width;\n"
" int input_height;\n"
" int input_size;\n"
" int output_width;\n"
" int output_height;\n"
" int output_size;\n"
" int slices;\n"
" float spatial_scale;\n"
"};\n"
// One thread produces one output bin (gid.x, gid.y) of plane gid.z, where
// gid.z = ob*slices+iz selects the ROI (ob) and the plane within it (iz).
// NOTE(review): gid.z is not bounds-checked here - presumably the dispatch
// size in z matches the number of planes exactly; confirm at the call site.
"kernel void ROI_pooling(const device M4 *in [[buffer(0)]],\n"
" const device M *roi [[buffer(1)]],\n"
" device M4 *out [[buffer(2)]],\n"
" constant ROI_shape &s [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= s.output_width || (int)gid.y >= s.output_height) return;\n"
" \n"
" int ob=gid.z/s.slices;\n"
" int iz=gid.z % s.slices;\n"
" \n"
// Each ROI record occupies 8 scalars: [0] input batch index, [1..4] x1, y1,
// x2, y2; the remaining entries are not read.
" auto b_roi=roi+ob*8; // roundup(5,4)=8\n"
" int ib=int(b_roi[0]);\n"
" int x1=round(float(b_roi[1])*s.spatial_scale);\n"
" int y1=round(float(b_roi[2])*s.spatial_scale);\n"
" int x2=round(float(b_roi[3])*s.spatial_scale);\n"
" int y2=round(float(b_roi[4])*s.spatial_scale);\n"
" \n"
// ROI extent is forced to at least one pixel in each direction.
" int roi_w=max(x2-x1+1,1);\n"
" int roi_h=max(y2-y1+1,1);\n"
// Bin sizes are computed in type M.
" auto bin_size_w=(M)roi_w/s.output_width;\n"
" auto bin_size_h=(M)roi_h/s.output_height;\n"
" \n"
// Half-open pixel range [start, end) of this bin, clamped to the input plane.
" int w_start=clamp(x1+(int)floor(gid.x*bin_size_w) ,0,s.input_width);\n"
" int w_end=clamp(x1+(int)ceil((gid.x+1)*bin_size_w),0,s.input_width);\n"
" int h_start=clamp(y1+(int)floor(gid.y*bin_size_h) ,0,s.input_height);\n"
" int h_end=clamp(y1+(int)ceil((gid.y+1)*bin_size_h),0,s.input_height);\n"
" \n"
// An empty bin produces 0; otherwise the maximum over the bin is written.
" int is_empty=(h_end <= h_start) || (w_end <= w_start);\n"
" auto z_in=in+(ib*s.slices+iz)*s.input_size;\n"
" auto max4=is_empty ? 0 : z_in[h_start*s.input_width+w_start];\n"
" for (int y=h_start; y<h_end; y++) {\n"
" auto y_in=z_in+y*s.input_width;\n"
" for (int x=w_start; x<w_end; x++) {\n"
" max4=max(max4,y_in[x]);\n"
" }\n"
" }\n"
" out[int(gid.z)*s.output_size+int(gid.y)*s.output_width+int(gid.x)]=max4;\n"
"}\n"
;
// Metal source for element-wise type-cast kernels, embedded as a C string.
// Every kernel converts exactly one element per thread: out[gid] = T(in[gid]).
// None of the kernels receives an element count, so none performs a bounds
// check on gid. The type M is not declared in this string; a typedef with that
// name appears in shader_MetalDefine_metal.
const char* shader_MetalCast_metal = 
"using namespace metal;\n"
// M -> int32
"kernel void cast_float_to_int32(const device M *in [[buffer(0)]],\n"
" device int *out [[buffer(1)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=int(in[int(gid)]);\n"
"}\n"
// int32 -> M
"kernel void cast_int32_to_float(const device int *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=M(in[int(gid)]);\n"
"}\n"
// uint8 -> M
"kernel void cast_uint8_to_float(const device uchar *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=M(in[int(gid)]);\n"
"}\n"
// uint8 -> int32. The destination is an int buffer, so convert straight to int
// rather than detouring through the floating-point storage type M.
"kernel void cast_uint8_to_int(const device uchar *in [[buffer(0)]],\n"
" device int *out [[buffer(1)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=int(in[int(gid)]);\n"
"}\n"
// M -> uint8
"kernel void cast_float_to_uint8(const device M *in [[buffer(0)]],\n"
" device uchar *out [[buffer(1)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=uchar(in[int(gid)]);\n"
"}\n"
;
// Metal source for the 1x1 convolution kernels, embedded as a C string.
//
// All kernels share the same bindings: in (buffer 0, M4), out (buffer 1, M4),
// conv1x1_constants (buffer 2), weights as M4x4 (buffer 3), bias as M4 (buffer 4).
// They differ only in how many outputs each thread produces (the tile shape
// encoded in the kernel name). Common structure of every kernel:
//   * the weight pointer starts at wt + output_slice_index * input_slice;
//   * the input pointer advances by cst.input_size for each input slice z;
//   * the input is indexed with cst.output_width as its row stride;
//   * accumulation happens in FLOAT4 starting from the bias value, and each
//     result is converted to M4 and passed through activate() before storing.
// M4, M4x4, FLOAT4, activate and conv_activation_type are not declared in this
// string (the typedefs appear in shader_MetalDefine_metal; activate and
// conv_activation_type are defined outside the visible part of this file).
const char* shader_MetalConvolution1x1_metal = 
"#define CONV_UNROLL (4)\n"
"#define CONV_UNROLL_L (8)\n"
"struct conv1x1_constants {\n"
" int input_size;\n"
" int input_slice;\n"
" int output_width;\n"
" int output_height;\n"
" int output_size;\n"
" int output_slice;\n"
" int output_channel;\n"
" int batch;\n"
" conv_activation_type activation;\n"
"};\n"
// conv1x1_w1h1: one output element per thread.
// gid.x / gid.y are the output column / row; gid.z = batch * output_slice + slice.
"kernel void conv1x1_w1h1(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
" int idx_w=gid.x;\n"
" int idx_h=gid.y;\n"
" int idx_c=gid.z % cst.output_slice;\n"
" int idx_b=gid.z/cst.output_slice;\n"
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
" auto xy_in0=in+(int)idx_b*cst.input_slice*cst.input_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_out=out+(int)idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" auto biasValue=FLOAT4(biasTerms[idx_c]);\n"
" FLOAT4 result0=biasValue;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in40=xy_in0[0];\n"
" auto w=xy_wt[z];\n"
" result0 += FLOAT4(in40*w);\n"
" xy_in0 += cst.input_size;\n"
" }\n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
"}\n"
// conv1x1_g1z4: CONV_UNROLL (4) consecutive elements of the flattened
// output_size range per thread. gid.x = group of 4 positions, gid.y = output
// slice, gid.z = batch. Stores are guarded by computeSize; the four loads per
// iteration are not. NOTE(review): for a partial last group the unguarded
// loads read past position output_size-1 -- confirm the buffers are padded.
"kernel void conv1x1_g1z4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n"
" \n"
" int rx=gid.x*CONV_UNROLL;\n"
" int uz=gid.y;\n"
" auto xy_wt=wt+uz*cst.input_slice;\n"
" auto xy_in0=in+(int)gid.z*cst.input_slice*cst.input_size+rx+0;\n"
" auto xy_out=out+(int)gid.z*cst.output_slice*cst.output_size+uz*cst.output_size+rx;\n"
" auto biasValue=FLOAT4(biasTerms[uz]);\n"
" FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n"
" int computeSize=min(cst.output_size-rx,CONV_UNROLL);\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in40=*xy_in0;\n"
" auto in41=*(xy_in0+1);\n"
" auto in42=*(xy_in0+2);\n"
" auto in43=*(xy_in0+3);\n"
" auto w=xy_wt[z];\n"
" \n"
" result0 += FLOAT4(in40*w);\n"
" result1 += FLOAT4(in41*w);\n"
" result2 += FLOAT4(in42*w);\n"
" result3 += FLOAT4(in43*w);\n"
" xy_in0 += cst.input_size;\n"
" }\n"
" \n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
" if (computeSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" if (computeSize>2) {xy_out[2]=activate(M4(result2),cst.activation); }\n"
" if (computeSize>3) {xy_out[3]=activate(M4(result3),cst.activation); }\n"
"}\n"
// conv1x1_g1z8: same scheme as conv1x1_g1z4 but with CONV_UNROLL_L (8)
// consecutive positions per thread. Stores are guarded by computeSize; loads are not.
"kernel void conv1x1_g1z8(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*CONV_UNROLL_L >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n"
" int rx=gid.x*CONV_UNROLL_L;\n"
" int uz=gid.y;\n"
" auto xy_wt=wt+uz*cst.input_slice;\n"
" auto xy_in0=in+(int)gid.z*cst.input_slice*cst.input_size+rx+0;\n"
" auto xy_out=out+(int)gid.z*cst.output_slice*cst.output_size+uz*cst.output_size+rx;\n"
" auto biasValue=FLOAT4(biasTerms[uz]);\n"
" FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n"
" FLOAT4 result4=biasValue,result5=biasValue,result6=biasValue,result7=biasValue;\n"
" int computeSize=min(cst.output_size-rx,CONV_UNROLL_L);\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in40=xy_in0[0];\n"
" auto in41=xy_in0[1];\n"
" auto in42=xy_in0[2];\n"
" auto in43=xy_in0[3];\n"
" auto in44=xy_in0[4];\n"
" auto in45=xy_in0[5];\n"
" auto in46=xy_in0[6];\n"
" auto in47=xy_in0[7];\n"
" auto w=xy_wt[z];\n"
" result0 += FLOAT4(in40*w);\n"
" result1 += FLOAT4(in41*w);\n"
" result2 += FLOAT4(in42*w);\n"
" result3 += FLOAT4(in43*w);\n"
" result4 += FLOAT4(in44*w);\n"
" result5 += FLOAT4(in45*w);\n"
" result6 += FLOAT4(in46*w);\n"
" result7 += FLOAT4(in47*w);\n"
" xy_in0 += cst.input_size;\n"
" }\n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
" if (computeSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" if (computeSize>2) {xy_out[2]=activate(M4(result2),cst.activation); }\n"
" if (computeSize>3) {xy_out[3]=activate(M4(result3),cst.activation); }\n"
" if (computeSize>4) {xy_out[4]=activate(M4(result4),cst.activation); }\n"
" if (computeSize>5) {xy_out[5]=activate(M4(result5),cst.activation); }\n"
" if (computeSize>6) {xy_out[6]=activate(M4(result6),cst.activation); }\n"
" if (computeSize>7) {xy_out[7]=activate(M4(result7),cst.activation); }\n"
"}\n"
// conv1x1_w4h2: a tile of 4 columns x 2 rows per thread (idx_w = gid.x*4,
// idx_h = gid.y*2); gid.z = batch * output_slice + slice. Stores are guarded
// by widthSize / heightSize; the eight loads per iteration are not.
"kernel void conv1x1_w4h2(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*4 >= cst.output_width || (int)gid.y*2 >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
" int idx_w=gid.x << 2;\n"
" int idx_h=gid.y << 1;\n"
" int idx_c=gid.z % cst.output_slice;\n"
" int idx_b=gid.z/cst.output_slice;\n"
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
" auto xy_in0=in+(int)idx_b*cst.input_slice*cst.input_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_out=out+(int)idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" auto biasValue=FLOAT4(biasTerms[idx_c]);\n"
" FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n"
" FLOAT4 result4=biasValue,result5=biasValue,result6=biasValue,result7=biasValue;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in40=xy_in0[0];\n"
" auto in41=xy_in0[1];\n"
" auto in42=xy_in0[2];\n"
" auto in43=xy_in0[3];\n"
" auto in44=xy_in0[cst.output_width+0];\n"
" auto in45=xy_in0[cst.output_width+1];\n"
" auto in46=xy_in0[cst.output_width+2];\n"
" auto in47=xy_in0[cst.output_width+3];\n"
" auto w=xy_wt[z];\n"
" result0 += FLOAT4(in40*w);\n"
" result1 += FLOAT4(in41*w);\n"
" result2 += FLOAT4(in42*w);\n"
" result3 += FLOAT4(in43*w);\n"
" result4 += FLOAT4(in44*w);\n"
" result5 += FLOAT4(in45*w);\n"
" result6 += FLOAT4(in46*w);\n"
" result7 += FLOAT4(in47*w);\n"
" xy_in0 += cst.input_size;\n"
" }\n"
" int widthSize=min(cst.output_width-idx_w,4);\n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
" if (widthSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" if (widthSize>2) {xy_out[2]=activate(M4(result2),cst.activation); }\n"
" if (widthSize>3) {xy_out[3]=activate(M4(result3),cst.activation); }\n"
" \n"
" int heightSize=min(cst.output_height-idx_h,2);\n"
" if(heightSize>1) {\n"
" /* true */ {xy_out[cst.output_width+0]=activate(M4(result4),cst.activation); }\n"
" if (widthSize>1) {xy_out[cst.output_width+1]=activate(M4(result5),cst.activation); }\n"
" if (widthSize>2) {xy_out[cst.output_width+2]=activate(M4(result6),cst.activation); }\n"
" if (widthSize>3) {xy_out[cst.output_width+3]=activate(M4(result7),cst.activation); }\n"
" }\n"
"}\n"
// conv1x1_w4h4: a tile of 4 columns x 4 rows per thread (idx_w = gid.x*4,
// idx_h = gid.y*4); resultRC holds row R, column C of the tile. Stores are
// guarded by widthSize / heightSize; the sixteen loads per iteration are not.
"kernel void conv1x1_w4h4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*4 >= cst.output_width || (int)gid.y*4 >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
" int idx_w=gid.x << 2;\n"
" int idx_h=gid.y << 2;\n"
" int idx_c=gid.z % cst.output_slice;\n"
" int idx_b=gid.z/cst.output_slice;\n"
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
" auto xy_in0=in+(int)idx_b*cst.input_slice*cst.input_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_out=out+(int)idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" auto biasValue=FLOAT4(biasTerms[idx_c]);\n"
" FLOAT4 result00=biasValue,result01=biasValue,result02=biasValue,result03=biasValue;\n"
" FLOAT4 result10=biasValue,result11=biasValue,result12=biasValue,result13=biasValue;\n"
" FLOAT4 result20=biasValue,result21=biasValue,result22=biasValue,result23=biasValue;\n"
" FLOAT4 result30=biasValue,result31=biasValue,result32=biasValue,result33=biasValue;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in00=xy_in0[0];\n"
" auto in01=xy_in0[1];\n"
" auto in02=xy_in0[2];\n"
" auto in03=xy_in0[3];\n"
" auto in10=xy_in0[cst.output_width+0];\n"
" auto in11=xy_in0[cst.output_width+1];\n"
" auto in12=xy_in0[cst.output_width+2];\n"
" auto in13=xy_in0[cst.output_width+3];\n"
" \n"
" auto in20=xy_in0[cst.output_width+cst.output_width+0];\n"
" auto in21=xy_in0[cst.output_width+cst.output_width+1];\n"
" auto in22=xy_in0[cst.output_width+cst.output_width+2];\n"
" auto in23=xy_in0[cst.output_width+cst.output_width+3];\n"
" auto in30=xy_in0[cst.output_width+cst.output_width+cst.output_width+0];\n"
" auto in31=xy_in0[cst.output_width+cst.output_width+cst.output_width+1];\n"
" auto in32=xy_in0[cst.output_width+cst.output_width+cst.output_width+2];\n"
" auto in33=xy_in0[cst.output_width+cst.output_width+cst.output_width+3];\n"
" auto w=xy_wt[z];\n"
" result00 += FLOAT4(in00*w);\n"
" result01 += FLOAT4(in01*w);\n"
" result02 += FLOAT4(in02*w);\n"
" result03 += FLOAT4(in03*w);\n"
" result10 += FLOAT4(in10*w);\n"
" result11 += FLOAT4(in11*w);\n"
" result12 += FLOAT4(in12*w);\n"
" result13 += FLOAT4(in13*w);\n"
" \n"
" result20 += FLOAT4(in20*w);\n"
" result21 += FLOAT4(in21*w);\n"
" result22 += FLOAT4(in22*w);\n"
" result23 += FLOAT4(in23*w);\n"
" result30 += FLOAT4(in30*w);\n"
" result31 += FLOAT4(in31*w);\n"
" result32 += FLOAT4(in32*w);\n"
" result33 += FLOAT4(in33*w);\n"
" \n"
" xy_in0 += cst.input_size;\n"
" }\n"
" int widthSize=min(cst.output_width-idx_w,4);\n"
" /* true */ *xy_out=activate(M4(result00),cst.activation);\n"
" if (widthSize>1) {xy_out[1]=activate(M4(result01),cst.activation); }\n"
" if (widthSize>2) {xy_out[2]=activate(M4(result02),cst.activation); }\n"
" if (widthSize>3) {xy_out[3]=activate(M4(result03),cst.activation); }\n"
" \n"
" int heightSize=min(cst.output_height-idx_h,4);\n"
" if(heightSize>1) {\n"
" /* true */ {xy_out[cst.output_width+0]=activate(M4(result10),cst.activation); }\n"
" if (widthSize>1) {xy_out[cst.output_width+1]=activate(M4(result11),cst.activation); }\n"
" if (widthSize>2) {xy_out[cst.output_width+2]=activate(M4(result12),cst.activation); }\n"
" if (widthSize>3) {xy_out[cst.output_width+3]=activate(M4(result13),cst.activation); }\n"
" }\n"
" if(heightSize>2) {\n"
" /* true */ {xy_out[cst.output_width+cst.output_width+0]=activate(M4(result20),cst.activation); }\n"
" if (widthSize>1) {xy_out[cst.output_width+cst.output_width+1]=activate(M4(result21),cst.activation); }\n"
" if (widthSize>2) {xy_out[cst.output_width+cst.output_width+2]=activate(M4(result22),cst.activation); }\n"
" if (widthSize>3) {xy_out[cst.output_width+cst.output_width+3]=activate(M4(result23),cst.activation); }\n"
" }\n"
" if(heightSize>3) {\n"
" /* true */ {xy_out[cst.output_width+cst.output_width+cst.output_width+0]=activate(M4(result30),cst.activation); }\n"
" if (widthSize>1) {xy_out[cst.output_width+cst.output_width+cst.output_width+1]=activate(M4(result31),cst.activation); }\n"
" if (widthSize>2) {xy_out[cst.output_width+cst.output_width+cst.output_width+2]=activate(M4(result32),cst.activation); }\n"
" if (widthSize>3) {xy_out[cst.output_width+cst.output_width+cst.output_width+3]=activate(M4(result33),cst.activation); }\n"
" }\n"
"}\n"
// conv1x1_w2c2: 2 columns x 2 output slices per thread. gid.z is decomposed
// with channel_pack = (output_channel+7)>>3 into a slice-pair index (idx_c is
// twice that) and a batch index. Stores to the second slice are guarded by
// channelSize, but biasTerms[idx_c+1] and the second slice's weights are read
// unconditionally. NOTE(review): confirm bias/weight buffers are padded to an
// even slice count.
"kernel void conv1x1_w2c2(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*2 >= cst.output_width || (int)gid.y >= cst.output_height) return;\n"
" int channel_pack=(cst.output_channel+7) >> 3;\n"
" int idx_w=gid.x << 1;\n"
" int idx_h=gid.y;\n"
" int idx_c=(gid.z % channel_pack) << 1;\n"
" int idx_b=gid.z/channel_pack;\n"
" \n"
" if(idx_b >= cst.batch || idx_c >= cst.output_slice) return;\n"
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
" auto xy_in0=in+(int)idx_b*cst.input_slice*cst.input_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_out=out+(int)idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" auto biasValue0=FLOAT4(biasTerms[idx_c]);\n"
" auto biasValue1=FLOAT4(biasTerms[idx_c+1]);\n"
" FLOAT4 result0=biasValue0,result1=biasValue0;\n"
" FLOAT4 result4=biasValue1,result5=biasValue1;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in40=xy_in0[0];\n"
" auto in41=xy_in0[1];\n"
" auto w0=xy_wt[z];\n"
" auto w1=xy_wt[cst.input_slice+z];\n"
" result0 += FLOAT4(in40*w0);\n"
" result1 += FLOAT4(in41*w0);\n"
" result4 += FLOAT4(in40*w1);\n"
" result5 += FLOAT4(in41*w1);\n"
" xy_in0 += cst.input_size;\n"
" }\n"
" int widthSize=min(cst.output_width-idx_w,2);\n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
" if (widthSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" \n"
" int channelSize=min(cst.output_slice-idx_c,2);\n"
" if(channelSize>1) {\n"
" /* true */ {xy_out[cst.output_size+0]=activate(M4(result4),cst.activation); }\n"
" if (widthSize>1) {xy_out[cst.output_size+1]=activate(M4(result5),cst.activation); }\n"
" }\n"
"}\n"
// conv1x1_w2h2: a tile of 2 columns x 2 rows per thread (idx_w = gid.x*2,
// idx_h = gid.y*2); gid.z = batch * output_slice + slice.
"kernel void conv1x1_w2h2(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*2 >= cst.output_width || (int)gid.y*2 >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
" int idx_w=gid.x << 1;\n"
" int idx_h=gid.y << 1;\n"
" int idx_c=gid.z % cst.output_slice;\n"
" int idx_b=gid.z/cst.output_slice;\n"
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
" auto xy_in0=in+(int)idx_b*cst.input_slice*cst.input_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_out=out+(int)idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" auto biasValue=FLOAT4(biasTerms[idx_c]);\n"
" FLOAT4 result0=biasValue,result1=biasValue;\n"
" FLOAT4 result4=biasValue,result5=biasValue;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in40=xy_in0[0];\n"
" auto in41=xy_in0[1];\n"
" auto in44=xy_in0[cst.output_width+0];\n"
" auto in45=xy_in0[cst.output_width+1];\n"
" auto w=xy_wt[z];\n"
" result0 += FLOAT4(in40*w);\n"
" result1 += FLOAT4(in41*w);\n"
" result4 += FLOAT4(in44*w);\n"
" result5 += FLOAT4(in45*w);\n"
" xy_in0 += cst.input_size;\n"
" }\n"
" int widthSize=min(cst.output_width-idx_w,2);\n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
" if (widthSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" \n"
" int heightSize=min(cst.output_height-idx_h,2);\n"
" if(heightSize>1) {\n"
" /* true */ {xy_out[cst.output_width+0]=activate(M4(result4),cst.activation); }\n"
" if (widthSize>1) {xy_out[cst.output_width+1]=activate(M4(result5),cst.activation); }\n"
" }\n"
"}\n"
// conv1x1_w2h2c2: a tile of 2 columns x 2 rows x 2 output slices per thread.
// result0/1/4/5 belong to slice idx_c (rows 0 and 1), result2/3/6/7 to slice
// idx_c+1. gid.z is decomposed with channel_pack exactly as in conv1x1_w2c2,
// and the same unconditional reads of the second slice's bias/weights apply.
"kernel void conv1x1_w2h2c2(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*2 >= cst.output_width || (int)gid.y*2 >= cst.output_height) return;\n"
" int channel_pack=(cst.output_channel+7) >> 3;\n"
" int idx_w=gid.x << 1;\n"
" int idx_h=gid.y << 1;\n"
" int idx_c=(gid.z % channel_pack) << 1;\n"
" int idx_b=gid.z/channel_pack;\n"
" if(idx_b >= cst.batch || idx_c >= cst.output_slice) return;\n"
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
" auto xy_in0=in+(int)idx_b*cst.input_slice*cst.input_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_out=out+(int)idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" auto biasValue0=FLOAT4(biasTerms[idx_c]);\n"
" auto biasValue1=FLOAT4(biasTerms[idx_c+1]);\n"
" FLOAT4 result0=biasValue0,result1=biasValue0;\n"
" FLOAT4 result4=biasValue0,result5=biasValue0;\n"
" FLOAT4 result2=biasValue1,result3=biasValue1;\n"
" FLOAT4 result6=biasValue1,result7=biasValue1;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in40=xy_in0[0];\n"
" auto in41=xy_in0[1];\n"
" auto in44=xy_in0[cst.output_width+0];\n"
" auto in45=xy_in0[cst.output_width+1];\n"
" auto w0=xy_wt[z];\n"
" auto w1=xy_wt[cst.input_slice+z];\n"
" result0 += FLOAT4(in40*w0);\n"
" result1 += FLOAT4(in41*w0);\n"
" result4 += FLOAT4(in44*w0);\n"
" result5 += FLOAT4(in45*w0);\n"
" result2 += FLOAT4(in40*w1);\n"
" result3 += FLOAT4(in41*w1);\n"
" result6 += FLOAT4(in44*w1);\n"
" result7 += FLOAT4(in45*w1);\n"
" xy_in0 += cst.input_size;\n"
" }\n"
" int widthSize=min(cst.output_width-idx_w,2);\n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
" if (widthSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" \n"
" int heightSize=min(cst.output_height-idx_h,2);\n"
" if(heightSize>1) {\n"
" /* true */ {xy_out[cst.output_width+0]=activate(M4(result4),cst.activation); }\n"
" if (widthSize>1) {xy_out[cst.output_width+1]=activate(M4(result5),cst.activation); }\n"
" }\n"
" \n"
" int channelSize=min(cst.output_slice-idx_c,2);\n"
" if(channelSize>1) {\n"
" /* true */ xy_out[cst.output_size]=activate(M4(result2),cst.activation);\n"
" if (widthSize>1) {xy_out[cst.output_size+1]=activate(M4(result3),cst.activation); }\n"
" \n"
" if(heightSize>1) {\n"
" /* true */ {xy_out[cst.output_size+cst.output_width+0]=activate(M4(result6),cst.activation); }\n"
" if (widthSize>1) {xy_out[cst.output_size+cst.output_width+1]=activate(M4(result7),cst.activation); }\n"
" }\n"
" }\n"
"}\n"
;
// Metal source for the im2col / matmul / col2im convolution path, embedded as
// a C string. M4, M4x4, FLOAT4, UP_DIV, activate and conv_activation_type are
// not declared in this string (the typedefs and UP_DIV appear in
// shader_MetalDefine_metal; activate and conv_activation_type are defined
// outside the visible part of this file).
const char* shader_MetalConvolutionGEMM_metal = 
// Convolution geometry shared by conv_im2col and conv_col2im.
"struct conv_im2col_cst {\n"
" int input_width;\n"
" int input_height;\n"
" int input_size;\n"
" int input_slice;\n"
" int output_width;\n"
" int output_height;\n"
" int output_size;\n"
" int output_slice;\n"
" int batch;\n"
" \n"
" int kernel_x;\n"
" int kernel_y;\n"
" int kernel_size;\n"
" int stride_x;\n"
" int stride_y;\n"
" int pad_x;\n"
" int pad_y;\n"
" int dilation_x;\n"
" int dilation_y;\n"
" conv_activation_type activation;\n"
"};\n"
// conv_im2col: one thread per (output x, output y, batch*input_slice + slice).
// For its output position the thread walks the kernel_y x kernel_x window
// (honouring stride, pad and dilation) and writes one M4 per tap into `cols`,
// storing zero for taps that fall outside the input image. The flattened
// output index (b*output_size + y*output_width + x) is split into index/4
// (selects the block of kernel_size*input_slice*4 elements) and index%4
// (position inside each group of 4); consecutive taps are 4 elements apart.
"kernel void conv_im2col(const device M4 *im [[buffer(0)]],\n"
" device M4 *cols [[buffer(1)]],\n"
" constant conv_im2col_cst& cst [[buffer(2)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" auto z=gid.z % cst.input_slice;\n"
" auto b=gid.z/cst.input_slice;\n"
" if ((int)gid.x<cst.output_width && (int)gid.y<cst.output_height && (int)b<cst.batch) {\n"
" int offset_x=gid.x*cst.stride_x-cst.pad_x;\n"
" int offset_y=gid.y*cst.stride_y-cst.pad_y;\n"
" int index=b*cst.output_size+gid.y*cst.output_width+gid.x;\n"
" int cols_y=index/4;\n"
" int cols_x=index % 4+z*cst.kernel_size*4;\n"
" \n"
" auto xy_cols=cols+cols_y*cst.kernel_size*cst.input_slice*4+cols_x;\n"
" auto xy_im=im+b*cst.input_size*cst.input_slice+z*cst.input_size;\n"
" for (int ky=0,src_y=offset_y; ky<cst.kernel_y; ky++,src_y += cst.dilation_y) {\n"
" for (int kx=0,src_x=offset_x; kx<cst.kernel_x; kx++,src_x += cst.dilation_x) {\n"
" auto pad=src_x<0 || src_y<0 || src_x >= cst.input_width || src_y >= cst.input_height;\n"
" xy_cols[(ky*cst.kernel_x+kx)*4]=pad ? 0 : xy_im[src_y*cst.input_width+src_x];\n"
" }\n"
" }\n"
" }\n"
"}\n"
// conv_col2im: one thread per (output x, output y, batch*output_slice + slice).
// Reads a single M4 from `cols` at row (index%4 + z*4) and column index/4,
// with row stride UP_DIV(output_size*batch, 4), adds the slice's bias, applies
// the activation and writes the result to `im`.
"kernel void conv_col2im(const device M4 *cols [[buffer(0)]],\n"
" device M4 *im [[buffer(1)]],\n"
" const device M4 *biasTerms [[buffer(2)]],\n"
" constant conv_im2col_cst& cst [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" auto z=gid.z % cst.output_slice;\n"
" auto b=gid.z/cst.output_slice;\n"
" if ((int)gid.x<cst.output_width && (int)gid.y<cst.output_height && (int)b<cst.batch) {\n"
" int index=b*cst.output_size+gid.y*cst.output_width+gid.x;\n"
" auto src_x=index/4;\n"
" auto src_y=index % 4+z*4;\n"
" auto src_y_stride=UP_DIV(cst.output_size*cst.batch,4);\n"
" \n"
" auto v=cols[(int)src_y*src_y_stride+(int)src_x]+biasTerms[(int)z];\n"
" im[(int)gid.z*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x]=activate(v,cst.activation);\n"
" }\n"
"}\n"
// Dimensions for matmul4x4; `group` scales the output row stride.
"struct matmul4x4_const {\n"
" int output_width;\n"
" int output_height;\n"
" int multi_length;\n"
" int group;\n"
"};\n"
// matmul4x4_template: each thread accumulates, over multi_length steps, the
// products of one 4x4 weight element with the four columns of one 4x4 input
// element (in FLOAT4), then stores the four results as OType values spaced
// output_width*group elements apart. gid.z offsets both the input and the
// weight rows. Threads with gid.x/gid.y outside the output extent do nothing;
// gid.z is not bounds-checked.
"template <typename IType,typename OType>\n"
"static inline void matmul4x4_template(const device IType *in,\n"
" device OType *out,\n"
" const device IType *kt,\n"
" constant matmul4x4_const &cst,\n"
" uint3 gid) {\n"
" if ((int)gid.x<cst.output_width && (int)gid.y<cst.output_height) {\n"
" auto ky=(int)gid.y+(int)gid.z*cst.output_height;\n"
" auto iy=(int)gid.x+(int)gid.z*cst.output_width;\n"
" auto off_in=in+iy*cst.multi_length;\n"
" auto off_wt=kt+ky*cst.multi_length;\n"
" auto off_out=out+iy+4*(int)gid.y*cst.output_width*cst.group;\n"
" \n"
" FLOAT4 result0=0,result1=0,result2=0,result3=0;\n"
" for (int k=0; k<cst.multi_length; ++k) {\n"
" auto w4x4=off_wt[k];\n"
" auto i4x4=off_in[k];\n"
" result0 += FLOAT4(w4x4*i4x4[0]);\n"
" result1 += FLOAT4(w4x4*i4x4[1]);\n"
" result2 += FLOAT4(w4x4*i4x4[2]);\n"
" result3 += FLOAT4(w4x4*i4x4[3]);\n"
" }\n"
" *off_out=OType(result0); off_out += cst.output_width*cst.group;\n"
" *off_out=OType(result1); off_out += cst.output_width*cst.group;\n"
" *off_out=OType(result2); off_out += cst.output_width*cst.group;\n"
" *off_out=OType(result3);\n"
" }\n"
"}\n"
// matmul4x4: kernel entry point instantiating the template for M4x4 inputs
// and M4 outputs.
"kernel void matmul4x4(const device M4x4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" const device M4x4 *kt [[buffer(2)]],\n"
" constant matmul4x4_const &cst [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" matmul4x4_template<M4x4,M4>(in,out,kt,cst,gid);\n"
"}\n"
;
// Metal source for the image resize kernels (nearest, bilinear, cubic),
// embedded as a C string.
//
// All three kernels share the same bindings: in (buffer 0, M4), out (buffer 1,
// M4), resize_shape (buffer 2) and a float4 `s` (buffer 3) that maps an output
// coordinate to a source coordinate: srcX = gid.x*s.x + s.y, srcY = gid.y*s.z + s.w.
// One thread produces one output element; threads outside
// (output_width, output_height, sliceMap) return immediately.
// M4 is not declared in this string; a typedef with that name appears in
// shader_MetalDefine_metal.
const char* shader_MetalResize_metal = 
"struct resize_shape {\n"
" int input_width;\n"
" int input_height;\n"
" int input_size;\n"
" int output_width;\n"
" int output_height;\n"
" int output_size;\n"
" int sliceMap;\n"
"};\n"
// resize_nearest: copies the source element at (floor(srcX), floor(srcY)).
// The source coordinate is clamped to the input image so that a negative
// offset or a scale that overshoots the last row/column cannot index outside
// the input buffer (resize_bilinear clamps the same way).
"kernel void resize_nearest(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant resize_shape &c [[buffer(2)]],\n"
" constant float4& s [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= c.output_width || (int)gid.y >= c.output_height || (int)gid.z >= c.sliceMap) return;\n"
" \n"
" float srcX=gid.x*s.x+s.y,srcY=gid.y*s.z+s.w;\n"
" int left=clamp(int(floor(srcX)),0,c.input_width-1);\n"
" int top=clamp(int(floor(srcY)),0,c.input_height-1);\n"
" \n"
" auto in_z=in+gid.z*c.input_size;\n"
" auto in_top=in_z+top*c.input_width;\n"
" out[int(gid.z)*c.output_size+int(gid.y)*c.output_width+int(gid.x)]=in_top[left];\n"
"}\n"
// resize_bilinear: blends the four neighbours around (srcX, srcY) in float4,
// with neighbour indices clamped to the input image.
"kernel void resize_bilinear(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant resize_shape &c [[buffer(2)]],\n"
" constant float4& s [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= c.output_width || (int)gid.y >= c.output_height || (int)gid.z >= c.sliceMap) return;\n"
" \n"
" float srcX=gid.x*s.x+s.y,srcY=gid.y*s.z+s.w;\n"
" int srcXInt=int(floor(srcX));\n"
" int srcYInt=int(floor(srcY));\n"
" int left=clamp(srcXInt,0,c.input_width-1);\n"
" int right=clamp(srcXInt+1,0,c.input_width-1);\n"
" int top=clamp(srcYInt,0,c.input_height-1);\n"
" int bottom=clamp(srcYInt+1,0,c.input_height-1);\n"
" float x2_factor=srcX-float(srcXInt);\n"
" float y2_factor=srcY-float(srcYInt);\n"
" float x1_factor=1-x2_factor;\n"
" float y1_factor=1-y2_factor;\n"
" \n"
" auto in_z=in+gid.z*c.input_size;\n"
" auto in_top=in_z+top*c.input_width;\n"
" auto in_bottom=in_z+bottom*c.input_width;\n"
" auto tl=float4(in_top[left])*x1_factor*y1_factor;\n"
" auto tr=float4(in_top[right])*x2_factor*y1_factor;\n"
" auto bl=float4(in_bottom[left])*x1_factor*y2_factor;\n"
" auto br=float4(in_bottom[right])*x2_factor*y2_factor;\n"
" out[int(gid.z)*c.output_size+int(gid.y)*c.output_width+int(gid.x)]=M4(tl+tr+bl+br);\n"
"}\n"
// resize_cubic_interpolation: evaluates a cubic polynomial through four
// samples A,B,C,D at fractional position `factor` between B and C
// (factor == 0 returns B).
"static inline float4 resize_cubic_interpolation(float4 A,float4 B,float4 C,float4 D,float factor) {\n"
" float4 a=(B-C)+0.5f*(B-A)+(D-C)*0.5f;\n"
" float4 b=C-((B-A)+(B-C))-(B+D)*0.5f;\n"
" float4 c=(C-A)*0.5f;\n"
" float4 d=B;\n"
" return ((a*factor+b)*factor+c)*factor+d;\n"
"}\n"
// resize_cubic: separable cubic interpolation over a 4x4 neighbourhood.
// The integer base coordinate is floor(x) / floor(y) -- the same value the
// fractional factor is measured from -- so that taps and factor stay
// consistent for negative source coordinates (int(x) truncates toward zero
// and would select a base one sample too far right/down there).
// Tap indices are clamped to the input image.
"kernel void resize_cubic(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant resize_shape &c [[buffer(2)]],\n"
" constant float4& s [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= c.output_width || (int)gid.y >= c.output_height || (int)gid.z >= c.sliceMap) return;\n"
" float x=gid.x*s.x+s.y,y=gid.y*s.z+s.w;\n"
" \n"
" int x0=int(floor(x));\n"
" int y0=int(floor(y));\n"
" float x_factor=x-float(x0);\n"
" float y_factor=y-float(y0);\n"
" \n"
" int4 xp=int4(x0-1,x0+0,x0+1,x0+2);\n"
" xp=clamp(xp,0,c.input_width-1);\n"
" \n"
" int4 yp=int4(y0-1,y0+0,y0+1,y0+2);\n"
" yp=clamp(yp,0,c.input_height-1);\n"
" \n"
" auto in_z=in+gid.z*c.input_size;\n"
" float4x4 ABCD;\n"
" for (int i=0; i<4; i++) {\n"
" auto in_y=in_z+yp[i]*c.input_width;\n"
" float4 A=float4(in_y[xp[0]]);\n"
" float4 B=float4(in_y[xp[1]]);\n"
" float4 C=float4(in_y[xp[2]]);\n"
" float4 D=float4(in_y[xp[3]]);\n"
" ABCD[i]=resize_cubic_interpolation(A,B,C,D,x_factor);\n"
" }\n"
" \n"
" auto val=M4(resize_cubic_interpolation(ABCD[0],ABCD[1],ABCD[2],ABCD[3],y_factor));\n"
" out[int(gid.z)*c.output_size+int(gid.y)*c.output_width+int(gid.x)]=val;\n"
"}\n"
;
// Metal source for the PReLU kernels, embedded as a C string.
// Both kernels keep non-negative lanes unchanged and multiply lanes whose sign
// bit is set by the slope. M4 is not declared in this string; a typedef with
// that name appears in shader_MetalDefine_metal.
const char* shader_MetalPReLU_metal = 
// Extents of the three grid dimensions used by prelu_slopes.
"struct prelu_shape {\n"
" int size;\n"
" int slice;\n"
" int batch;\n"
"};\n"
// prelu: a single slope shared by every element; one M4 per thread.
// No element count is passed in, so gid is not bounds-checked here.
"kernel void prelu(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant float &slope [[buffer(2)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" auto v4=in[int(gid)];\n"
" out[int(gid)]=select(v4,M4(slope)*v4,signbit(v4));\n"
"}\n"
// prelu_slopes: one float4 slope per slice (indexed by gid.y).
// Grid is (size, slice, batch). All three coordinates are checked against
// prelu_shape so a grid rounded up in any dimension cannot read or write
// outside the tensors.
"kernel void prelu_slopes(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" const device float4 *slope [[buffer(2)]],\n"
" constant prelu_shape& s [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) { // size,slice,batch\n"
" if ((int)gid.x >= s.size || (int)gid.y >= s.slice || (int)gid.z >= s.batch) return;\n"
" \n"
" int z=gid.z*s.slice+gid.y;\n"
" auto v4=in[z*s.size+int(gid.x)];\n"
" out[z*s.size+int(gid.x)]=select(v4,M4(slope[int(gid.y)])*v4,signbit(v4));\n"
"}\n"
;
// Common Metal prelude, embedded as a C string. It provides:
//   * UP_DIV / ROUND_UP integer helpers;
//   * the storage types M, M2, ... M4x4 (float or half, selected by
//     MNN_METAL_FULL_PRECISION) and the compute types FLOAT, ... FLOAT4x4
//     (float or half, selected by MNN_METAL_FLOAT32_COMPUTER);
//   * inside namespace MNN: integer limit macros, num_limits<T>, a few integer
//     helper functions, and the short4x4 / char4x4 wrapper structs.
// The negative limit macros are parenthesized, and INT32_MIN is spelled
// (-2147483647-1) so that it has type int; the literal 2147483648 does not fit
// in int, which would give the unparenthesized form a wider type.
const char* shader_MetalDefine_metal = 
"using namespace metal;\n"
"// –––––––––––––––––––––––––––––––––––––––––––––––––––\n"
"// Macro\n"
"// –––––––––––––––––––––––––––––––––––––––––––––––––––\n"
"#define UP_DIV(x,y) ( ((x)+(y)-1)/(y) )\n"
"#define ROUND_UP(x,y) ( ((x)+(y)-1)/(y)*(y) )\n"
"// whether store with float32\n"
"#define MNN_METAL_FULL_PRECISION 0 // should edit in .h too\n"
"// whether computer with float32 when store with float16\n"
"#define MNN_METAL_FLOAT32_COMPUTER 1 //\n"
"#if MNN_METAL_FULL_PRECISION\n"
"typedef float M;\n"
"typedef float2 M2;\n"
"typedef float3 M3;\n"
"typedef float4 M4;\n"
"typedef float2x2 M2x2;\n"
"typedef float2x3 M2x3;\n"
"typedef float2x4 M2x4;\n"
"typedef float3x2 M3x2;\n"
"typedef float3x3 M3x3;\n"
"typedef float3x4 M3x4;\n"
"typedef float4x2 M4x2;\n"
"typedef float4x3 M4x3;\n"
"typedef float4x4 M4x4;\n"
"#else\n"
"typedef half M;\n"
"typedef half2 M2;\n"
"typedef half3 M3;\n"
"typedef half4 M4;\n"
"typedef half2x2 M2x2;\n"
"typedef half2x3 M2x3;\n"
"typedef half2x4 M2x4;\n"
"typedef half3x2 M3x2;\n"
"typedef half3x3 M3x3;\n"
"typedef half3x4 M3x4;\n"
"typedef half4x2 M4x2;\n"
"typedef half4x3 M4x3;\n"
"typedef half4x4 M4x4;\n"
"#endif\n"
"#if MNN_METAL_FLOAT32_COMPUTER\n"
"typedef float FLOAT;\n"
"typedef float2 FLOAT2;\n"
"typedef float3 FLOAT3;\n"
"typedef float4 FLOAT4;\n"
"typedef float2x2 FLOAT2x2;\n"
"typedef float2x3 FLOAT2x3;\n"
"typedef float2x4 FLOAT2x4;\n"
"typedef float3x2 FLOAT3x2;\n"
"typedef float3x3 FLOAT3x3;\n"
"typedef float3x4 FLOAT3x4;\n"
"typedef float4x2 FLOAT4x2;\n"
"typedef float4x3 FLOAT4x3;\n"
"typedef float4x4 FLOAT4x4;\n"
"#else\n"
"typedef half FLOAT;\n"
"typedef half2 FLOAT2;\n"
"typedef half3 FLOAT3;\n"
"typedef half4 FLOAT4;\n"
"typedef half2x2 FLOAT2x2;\n"
"typedef half2x3 FLOAT2x3;\n"
"typedef half2x4 FLOAT2x4;\n"
"typedef half3x2 FLOAT3x2;\n"
"typedef half3x3 FLOAT3x3;\n"
"typedef half3x4 FLOAT3x4;\n"
"typedef half4x2 FLOAT4x2;\n"
"typedef half4x3 FLOAT4x3;\n"
"typedef half4x4 FLOAT4x4;\n"
"#endif\n"
"namespace MNN {\n"
" \n"
" // –––––––––––––––––––––––––––––––––––––––––––––––––––\n"
" // Number Limit\n"
" // –––––––––––––––––––––––––––––––––––––––––––––––––––\n"
"#define INT8_MAX 127\n"
"#define INT8_MIN (-128)\n"
"#define INT16_MAX 32767\n"
"#define INT16_MIN (-32768)\n"
"#define INT32_MAX 2147483647\n"
"#define INT32_MIN (-2147483647-1)\n"
"#define UINT8_MAX 255\n"
"#define UINT16_MAX 65535\n"
"#define UINT32_MAX 4294967295U\n"
" \n"
// num_limits<T>: min()/max() for the integer types below; the primary
// template returns 0 for any other type. Every member returns int.
" template<typename T> struct num_limits {\n"
" static int max() { return 0; };\n"
" static int min() { return 0; };\n"
" };\n"
" template<> struct num_limits<char> {\n"
" static int max() { return INT8_MAX; };\n"
" static int min() { return INT8_MIN; };\n"
" };\n"
" template<> struct num_limits<uchar> {\n"
" static int max() { return UINT8_MAX; };\n"
" static int min() { return 0; };\n"
" };\n"
" template<> struct num_limits<short> {\n"
" static int max() { return INT16_MAX; };\n"
" static int min() { return INT16_MIN; };\n"
" };\n"
" template<> struct num_limits<ushort> {\n"
" static int max() { return UINT16_MAX; };\n"
" static int min() { return 0; };\n"
" };\n"
" template<> struct num_limits<int> {\n"
" static int max() { return INT32_MAX; };\n"
" static int min() { return INT32_MIN; };\n"
" };\n"
// NOTE(review): max() is declared to return int, so UINT32_MAX (4294967295U)
// does not survive the conversion. The signature is left unchanged here
// because callers may rely on the int return type -- confirm and widen if safe.
" template<> struct num_limits<uint> {\n"
" static int max() { return UINT32_MAX; };\n"
" static int min() { return 0; };\n"
" };\n"
" \n"
" // –––––––––––––––––––––––––––––––––––––––––––––––––––\n"
" // Function\n"
" // –––––––––––––––––––––––––––––––––––––––––––––––––––\n"
// dot: integer dot product of two int4 vectors.
" inline int dot(int4 i4,int4 w4) {\n"
" return i4[0]*w4[0]+i4[1]*w4[1]+i4[2]*w4[2]+i4[3]*w4[3];\n"
" }\n"
" \n"
// saturate_round_x2_high_mul: twice the high half of the product a*b (mulhi).
" template <typename T>\n"
" inline T saturate_round_x2_high_mul(T a,int b) {\n"
" return mulhi(a,b)*2;\n"
" }\n"
" \n"
// round_divide_by_pot: arithmetic shift right by `exponent`, adding 1 when the
// discarded low bits exceed the threshold (mask>>1, plus 1 for negative x).
" template <typename T>\n"
" inline T round_divide_by_pot(T x,int exponent) {\n"
" int mask=(1 << exponent)-1;\n"
" T remainder=x & mask;\n"
" T threshold=(mask >> 1)+T(x<0);\n"
" return (x >> exponent)+T(remainder>threshold);\n"
" }\n"
" \n"
" // –––––––––––––––––––––––––––––––––––––––––––––––––––\n"
" // Typedef\n"
" // –––––––––––––––––––––––––––––––––––––––––––––––––––\n"
" \n"
// short4x4: four short4 rows with operator[] for the thread, device and
// threadgroup address spaces and explicit conversions to half4x4 / float4x4.
" typedef struct short4x4 {\n"
" private:\n"
" short4 v[4];\n"
" public:\n"
" short4x4(short4 a) {\n"
" v[0]=a; v[1]=a; v[2]=a; v[3]=a;\n"
" }\n"
" short4x4(short4 a,short4 b,short4 c,short4 d) {\n"
" v[0]=a; v[1]=b; v[2]=c; v[3]=d;\n"
" }\n"
" \n"
" inline thread short4& operator[] (const int index) {\n"
" return v[index];\n"
" }\n"
" inline device short4& operator[] (const int index) device {\n"
" return v[index];\n"
" }\n"
" inline threadgroup short4& operator[] (const int index) threadgroup {\n"
" return v[index];\n"
" }\n"
" \n"
" inline const thread short4& operator[] (const int index) const {\n"
" return v[index];\n"
" }\n"
" inline const device short4& operator[] (const int index) const device {\n"
" return v[index];\n"
" }\n"
" inline const threadgroup short4& operator[] (const int index) const threadgroup {\n"
" return v[index];\n"
" }\n"
" \n"
" inline explicit operator half4x4() const {\n"
" return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n"
" }\n"
" inline explicit operator half4x4() const device{\n"
" return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n"
" }\n"
" inline explicit operator half4x4() const threadgroup {\n"
" return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n"
" }\n"
" \n"
" inline explicit operator float4x4() const {\n"
" return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n"
" }\n"
" inline explicit operator float4x4() const device {\n"
" return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n"
" }\n"
" inline explicit operator float4x4() const threadgroup {\n"
" return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n"
" }\n"
" } short4x4;\n"
" \n"
// char4x4: same interface as short4x4, with char4 rows.
" typedef struct char4x4 {\n"
" private:\n"
" char4 v[4];\n"
" public:\n"
" char4x4(char4 a) {\n"
" v[0]=a; v[1]=a; v[2]=a; v[3]=a;\n"
" }\n"
" char4x4(char4 a,char4 b,char4 c,char4 d) {\n"
" v[0]=a; v[1]=b; v[2]=c; v[3]=d;\n"
" }\n"
" \n"
" inline thread char4& operator[] (const int index) {\n"
" return v[index];\n"
" }\n"
" inline device char4& operator[] (const int index) device {\n"
" return v[index];\n"
" }\n"
" inline threadgroup char4& operator[] (const int index) threadgroup {\n"
" return v[index];\n"
" }\n"
" \n"
" inline const thread char4& operator[] (const int index) const {\n"
" return v[index];\n"
" }\n"
" inline const device char4& operator[] (const int index) const device {\n"
" return v[index];\n"
" }\n"
" inline const threadgroup char4& operator[] (const int index) const threadgroup {\n"
" return v[index];\n"
" }\n"
" \n"
" inline explicit operator half4x4() const {\n"
" return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n"
" }\n"
" inline explicit operator half4x4() const device {\n"
" return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n"
" }\n"
" inline explicit operator half4x4() const threadgroup {\n"
" return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n"
" }\n"
" \n"
" inline explicit operator float4x4() const {\n"
" return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n"
" }\n"
" inline explicit operator float4x4() const device {\n"
" return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n"
" }\n"
" inline explicit operator float4x4() const threadgroup {\n"
" return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n"
" }\n"
" } char4x4;\n"
"}\n"
;
// Embedded Metal shader source for the element-wise unary kernels.
// The text declares a unary_shape struct, a handful of float4 helper
// functions, and a define_op(op) macro that expands to one kernel named
// unary_<op>_x4 for every operation listed at the end of the string.
// M4 is used but not defined inside this string; presumably it is supplied by
// a shared prelude or a compile-time macro - TODO confirm against the host code.
const char* shader_MetalUnary_metal = 
// Shape block bound at buffer(2). The kernel below reads width and size;
// height is declared but never read in this source.
"struct unary_shape {\n"
" int width;\n"
" int height;\n"
" int size;\n"
"};\n"
// Helper operations on float4 that the macro can call by name.
"static inline float4 neg(float4 V) { return -V; }\n"
"static inline float4 square(float4 V) { return V*V; }\n"
"static inline float4 expm1(float4 V) {return exp(V)-1;}\n"
"static inline float4 reciprocal(float4 V) {return 1.0/(V);}\n"
"static inline float4 sigmoid(float4 V) {return 1.f/(1.f+exp(-V));}\n"
"static inline float4 log1p(float4 V) {return log(1.f+V);}\n"
// hardswish: V * min(max(V+3, 0), 6) / 6.
"static inline float4 hardswish(float4 V) {\n"
" return (float4)(1.0/6.0)*(V*min(max(V+(float4)3,0),(float4)6));\n"
"}\n"
// gelu: 0.5 * V * (1 + tanh(0.79788458 * (V + 0.044715 * V^3))).
"static inline float4 gelu(float4 V) {\n"
" float4 temp=(float4)0.044715f*V*V*V;\n"
" temp=0.79788458f*(temp+V);\n"
" float4 result=((float4)1.0f+tanh(temp))*V*(float4)0.5f;\n"
" return result;\n"
"}\n"
// Kernel generator. The pieces below carry no newline escape until the final
// closing brace, so after literal concatenation the whole #define occupies a
// single line of shader text.
// Generated kernel: computes off = gid.z*size + gid.y*width + gid.x, loads
// in[off], converts it to float4, applies op, casts the result back to M4 and
// stores it to out[off]. Only gid.x is bounds-checked (against width); gid.y
// and gid.z are used unchecked, so the dispatch is presumably sized exactly in
// those dimensions - TODO confirm against the host code.
"#define define_op(op) "
"kernel void unary_##op##_x4(const device M4 *in [[buffer(0)]],"
" device M4 *out [[buffer(1)]],"
" device unary_shape& s [[buffer(2)]],"
" uint3 gid [[thread_position_in_grid]]) { "
" if (gid.x<(uint)s.width) { "
" int off=gid.z*s.size+gid.y*s.width+gid.x; "
" out[off]=(M4)(op((float4)(in[off]))); "
" } "
"}\n"
// Kernel instantiations. Operations that are not among the helpers defined
// above (abs, floor, ceil, sqrt, ...) must resolve to functions declared
// elsewhere, presumably the Metal standard library - TODO confirm.
// NOTE(review): the reciprocal line has no trailing semicolon, unlike its
// neighbours; it is left untouched because the shader text must not change here.
"define_op(abs);\n"
"define_op(floor);\n"
"define_op(ceil);\n"
"define_op(expm1);\n"
"define_op(square);\n"
"define_op(sqrt);\n"
"define_op(rsqrt);\n"
"define_op(exp);\n"
"define_op(log);\n"
"define_op(sin);\n"
"define_op(cos);\n"
"define_op(tan);\n"
"define_op(asin);\n"
"define_op(acos);\n"
"define_op(atan);\n"
"define_op(neg);\n"
"define_op(reciprocal)\n"
"define_op(tanh);\n"
"define_op(sigmoid);\n"
"define_op(sign);\n"
"define_op(log1p);\n"
"define_op(cosh);\n"
"define_op(sinh);\n"
"define_op(acosh);\n"
"define_op(asinh);\n"
"define_op(atanh);\n"
"define_op(round);\n"
"define_op(hardswish);\n"
"define_op(gelu);\n"
;
// Embedded Metal shader source for the element-wise binary kernels.
// The text declares a binary_op_shape struct, a define_op(op) macro that
// expands to a kernel named binary_<op>_x1, scalar helper functions on M, and
// one macro instantiation per supported operation.
// M is used but not defined inside this string; presumably it is supplied by
// a shared prelude or a compile-time macro - TODO confirm against the host code.
const char* shader_MetalBinary_metal = 
// Parameter block bound at buffer(3).
"struct binary_op_shape {\n"
" int i0stride;\n"
" int i1stride;\n"
" int output_data_count;\n"
" int activationType;\n"
"};\n"
// Kernel generator. The pieces below carry no newline escape until the final
// closing brace, so after literal concatenation the whole #define occupies a
// single line of shader text.
// Generated kernel: returns early when gid >= output_data_count, reads
// in0[i0stride*gid] and in1[i1stride*gid], applies op to the pair, replaces a
// negative result with zero when activationType equals 1, and writes the value
// to out[gid].
// A stride of 0 makes every thread read element 0 of that input; presumably
// this is how a scalar operand is broadcast - TODO confirm against the caller.
"#define define_op(op) "
"kernel void binary_##op##_x1(const device M *in0 [[buffer(0)]],"
" const device M *in1 [[buffer(1)]],"
" device M *out [[buffer(2)]],"
" constant binary_op_shape& s [[buffer(3)]],"
" uint gid [[thread_position_in_grid]]) {"
" if ((int)gid >= s.output_data_count) return;"
" auto V0=in0[s.i0stride*int(gid)];"
" auto V1=in1[s.i1stride*int(gid)];"
" auto val=op(V0,V1);"
" if(s.activationType == 1) {"
" val=(val<(M)0 ? (M)0 : val);"
" }"
" out[int(gid)]=val;"
"}\n"
// Scalar helpers that the macro can call by name.
"static inline M add(M V1,M V2) {\n"
" return V1+V2;\n"
"}\n"
"static inline M sub(M V1,M V2) {\n"
" return V1-V2;\n"
"}\n"
"static inline M mul(M V1,M V2) {\n"
" return V1*V2;\n"
"}\n"
// div yields 0 instead of dividing when the divisor compares equal to 0.
"static inline M div(M V1,M V2) {\n"
" return V2 == 0 ? 0 : V1/V2;\n"
"}\n"
"static inline M squared_diff(M V1,M V2) {\n"
" return (V1-V2)*(V1-V2);\n"
"}\n"
// mod forwards to fmod; floormod computes x - floor(x/y)*y.
"static inline M mod(M V1,M V2) {\n"
" return fmod(V1,V2);\n"
"}\n"
"static inline M floormod(M x,M y) {\n"
" return x-floor(x/y)*y;\n"
"}\n"
// Kernel instantiations. max, min and pow are not defined in this string and
// must resolve to functions declared elsewhere, presumably the Metal standard
// library - TODO confirm.
"define_op(add)\n"
"define_op(sub)\n"
"define_op(mul)\n"
"define_op(div)\n"
"define_op(max)\n"
"define_op(min)\n"
"define_op(pow)\n"
"define_op(floormod)\n"
"define_op(mod)\n"
"define_op(squared_diff)\n"
;
// Embedded Metal shader source for three two-input element-wise kernels:
// eltwise_prod, eltwise_max and eltwise_add.
// Every kernel takes in0 at buffer(0), in1 at buffer(1), out at buffer(2) and
// an int4 shape at buffer(3). Only shape.x is read: a thread writes
// out[gid] only when gid < shape.x, so surplus threads do nothing.
// M is used but not defined inside this string; presumably it is supplied by
// a shared prelude or a compile-time macro - TODO confirm against the host code.
const char* shader_MetalEltwise_metal = 
// eltwise_prod: out[gid] = in0[gid] * in1[gid].
"kernel void eltwise_prod(device const M *in0 [[buffer(0)]],\n"
" device const M *in1 [[buffer(1)]],\n"
" device M *out [[buffer(2)]],\n"
" constant int4& shape [[buffer(3)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" if ((int)gid<shape.x) {\n"
" out[(int)gid]=in0[(int)gid]*in1[(int)gid];\n"
" }\n"
"}\n"
// eltwise_max: out[gid] = max(in0[gid], in1[gid]).
"kernel void eltwise_max(device const M *in0 [[buffer(0)]],\n"
" device const M *in1 [[buffer(1)]],\n"
" device M *out [[buffer(2)]],\n"
" constant int4& shape [[buffer(3)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" if ((int)gid<shape.x) {\n"
" out[(int)gid]=max(in0[(int)gid],in1[(int)gid]);\n"
" }\n"
"}\n"
// eltwise_add: out[gid] = in0[gid] + in1[gid].
"kernel void eltwise_add(device const M *in0 [[buffer(0)]],\n"
" device const M *in1 [[buffer(1)]],\n"
" device M *out [[buffer(2)]],\n"
" constant int4& shape [[buffer(3)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" if ((int)gid<shape.x) {\n"
" out[(int)gid]=in0[(int)gid]+in1[(int)gid];\n"
" }\n"
"}\n"
;
